viettelcloud-aiplatform 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
@@ -0,0 +1,637 @@
1
+ # Copyright 2025 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Unit tests for runtime_loader module.
17
+
18
+ Tests runtime loading from various sources including GitHub, HTTP, and filesystem.
19
+ """
20
+
21
+ from unittest.mock import MagicMock, patch
22
+
23
+ import pytest
24
+
25
+ from viettelcloud.aiplatform.trainer.backends.container import runtime_loader
26
+ from viettelcloud.aiplatform.trainer.constants import constants
27
+ from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
28
+ from viettelcloud.aiplatform.trainer.types import types as base_types
29
+
30
# Sample runtime YAML data for testing.
#
# Mirrors the shape of a ClusterTrainingRuntime manifest as consumed by the
# runtime loader: metadata (name + framework label) and a spec whose single
# replicatedJobs entry nests a pod template holding the trainer container.
#
# NOTE(review): tests that tweak nested fields (e.g. metadata.name) must
# deep-copy this dict first — a shallow .copy() shares the nested dicts and
# mutations would leak across parametrized test cases.
SAMPLE_RUNTIME_YAML = {
    "apiVersion": "trainer.kubeflow.org/v1alpha1",
    "kind": "ClusterTrainingRuntime",
    "metadata": {
        # Name defaults to the project-wide default training runtime.
        "name": constants.DEFAULT_TRAINING_RUNTIME,
        "labels": {"trainer.kubeflow.org/framework": "torch"},
    },
    "spec": {
        "mlPolicy": {"numNodes": 1},
        "template": {
            "spec": {
                "replicatedJobs": [
                    {
                        "name": "node",
                        "template": {
                            "spec": {
                                "template": {
                                    "spec": {
                                        "containers": [
                                            {
                                                "name": "trainer",
                                                "image": "pytorch/pytorch:2.0.0",
                                            }
                                        ]
                                    }
                                }
                            }
                        },
                    }
                ]
            }
        },
    },
}
65
+
66
+
67
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="parse github url",
            expected_status=SUCCESS,
            config={
                "url": "github://kubeflow/trainer",
                "expected_type": "github",
                "expected_path": "kubeflow/trainer",
            },
        ),
        TestCase(
            name="parse github url with path",
            expected_status=SUCCESS,
            config={
                "url": "github://myorg/myrepo/custom/path",
                "expected_type": "github",
                "expected_path": "myorg/myrepo/custom/path",
            },
        ),
        TestCase(
            name="parse https url",
            expected_status=SUCCESS,
            config={
                "url": "https://example.com/runtime.yaml",
                "expected_type": "https",
                "expected_path": "https://example.com/runtime.yaml",
            },
        ),
        TestCase(
            name="parse http url",
            expected_status=SUCCESS,
            config={
                "url": "http://example.com/runtime.yaml",
                "expected_type": "http",
                "expected_path": "http://example.com/runtime.yaml",
            },
        ),
        TestCase(
            name="parse file url",
            expected_status=SUCCESS,
            config={
                "url": "file:///path/to/runtime.yaml",
                "expected_type": "file",
                "expected_path": "/path/to/runtime.yaml",
            },
        ),
        TestCase(
            name="parse absolute path",
            expected_status=SUCCESS,
            config={
                "url": "/absolute/path/to/runtime.yaml",
                "expected_type": "file",
                "expected_path": "/absolute/path/to/runtime.yaml",
            },
        ),
        TestCase(
            name="parse unsupported scheme",
            expected_status=FAILED,
            config={"url": "ftp://example.com/runtime.yaml"},
            expected_error=ValueError,
        ),
    ],
)
def test_parse_source_url(test_case):
    """Exercise _parse_source_url across every supported URL scheme."""
    print("Executing test:", test_case.name)
    cfg = test_case.config
    try:
        # The parser splits a source URL into (source type, resolved path).
        scheme, resolved = runtime_loader._parse_source_url(cfg["url"])

        assert test_case.expected_status == SUCCESS
        assert scheme == cfg["expected_type"]
        assert resolved == cfg["expected_path"]

    except Exception as err:
        # Unsupported schemes are expected to raise; the raised type must
        # match exactly the one declared on the test case.
        assert type(err) is test_case.expected_error
    print("test execution complete")
145
+
146
+
147
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="load from default github",
            expected_status=SUCCESS,
            config={
                "github_path": "kubeflow/trainer",
                "discovered_files": ["torch_distributed.yaml"],
                "expected_runtime_name": constants.DEFAULT_TRAINING_RUNTIME,
                "expected_framework": "torch",
            },
        ),
        TestCase(
            name="load from custom github",
            expected_status=SUCCESS,
            config={
                "github_path": "myorg/myrepo",
                "discovered_files": ["custom_runtime.yaml"],
                "expected_runtime_name": "custom-runtime",
                "expected_framework": "custom",
            },
        ),
        TestCase(
            name="load from github no files",
            expected_status=SUCCESS,
            config={
                "github_path": "kubeflow/trainer",
                "discovered_files": [],
                "expected_count": 0,
            },
        ),
        TestCase(
            name="load from github invalid path",
            expected_status=SUCCESS,
            config={
                "github_path": "invalid",
                "expected_count": 0,
            },
        ),
    ],
)
def test_load_from_github_url(test_case):
    """Test loading runtimes from GitHub URLs.

    Mocks out file discovery and manifest fetching so that
    _load_from_github_url can be exercised without network access.
    """
    import copy  # stdlib; needed only for the deep copy below

    print("Executing test:", test_case.name)
    try:
        with (
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._discover_github_runtime_files"
            ) as mock_discover,
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._fetch_runtime_from_github"
            ) as mock_fetch,
        ):
            if test_case.name == "load from github invalid path":
                # Don't set up mocks for invalid path test
                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
                assert len(runtimes) == test_case.config["expected_count"]
            else:
                mock_discover.return_value = test_case.config.get("discovered_files", [])

                # Create runtime YAML with custom name/framework if specified.
                # BUG FIX: the original used SAMPLE_RUNTIME_YAML.copy() — a
                # shallow copy — so the metadata mutations below wrote through
                # to the shared module-level SAMPLE_RUNTIME_YAML and leaked
                # state between parametrized cases. deepcopy isolates each run.
                runtime_yaml = copy.deepcopy(SAMPLE_RUNTIME_YAML)
                if "expected_runtime_name" in test_case.config:
                    runtime_yaml["metadata"]["name"] = test_case.config["expected_runtime_name"]
                    runtime_yaml["metadata"]["labels"]["trainer.kubeflow.org/framework"] = (
                        test_case.config["expected_framework"]
                    )
                mock_fetch.return_value = runtime_yaml

                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])

                if "expected_count" in test_case.config:
                    assert len(runtimes) == test_case.config["expected_count"]
                else:
                    # Exactly one runtime, carrying the customised metadata.
                    assert len(runtimes) == 1
                    assert runtimes[0].name == test_case.config["expected_runtime_name"]
                    assert runtimes[0].trainer.framework == test_case.config["expected_framework"]

            assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
231
+
232
+
233
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="priority order github sources",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
                "expected_count": 2,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME, "deepspeed-distributed"],
            },
        ),
        TestCase(
            name="duplicate runtime names skipped",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
                "duplicate_names": True,
                "expected_count": 1,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME],
            },
        ),
        TestCase(
            name="fallback to defaults",
            expected_status=SUCCESS,
            config={
                "sources": ["github://myorg/myrepo"],
                "no_github_runtimes": True,
                "expected_count": 1,
                "expected_names": [constants.DEFAULT_TRAINING_RUNTIME],
            },
        ),
    ],
)
def test_list_training_runtimes_from_sources(test_case):
    """Test listing runtimes from multiple sources.

    Covers three behaviours of list_training_runtimes_from_sources:
    * runtimes from every source are aggregated, in source order;
    * a runtime whose name already appeared is skipped (first source wins);
    * when no source yields runtimes, the built-in defaults are used.
    """
    print("Executing test:", test_case.name)
    try:
        with (
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._load_from_github_url"
            ) as mock_github,
            patch(
                "viettelcloud.aiplatform.trainer.backends.container.runtime_loader._create_default_runtimes"
            ) as mock_defaults,
        ):
            # Mock wiring is keyed off the test-case name because each
            # scenario needs a different _load_from_github_url behaviour.
            if test_case.name == "priority order github sources":
                # Two sources, each yielding one distinct runtime.
                torch_runtime = base_types.Runtime(
                    name=constants.DEFAULT_TRAINING_RUNTIME,
                    trainer=base_types.RuntimeTrainer(
                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                        framework="torch",
                        num_nodes=1,
                        image="example.com/container",
                    ),
                )
                deepspeed_runtime = base_types.Runtime(
                    name="deepspeed-distributed",
                    trainer=base_types.RuntimeTrainer(
                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                        framework="deepspeed",
                        num_nodes=1,
                        image="example.com/container",
                    ),
                )
                # side_effect: first source -> [torch], second -> [deepspeed].
                mock_github.side_effect = [[torch_runtime], [deepspeed_runtime]]
                mock_defaults.return_value = []

            elif test_case.name == "duplicate runtime names skipped":
                # Both sources return a runtime with the SAME name; only the
                # first one (num_nodes=1) should survive deduplication.
                torch_runtime_1 = base_types.Runtime(
                    name=constants.DEFAULT_TRAINING_RUNTIME,
                    trainer=base_types.RuntimeTrainer(
                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                        framework="torch",
                        num_nodes=1,
                        image="example.com/container",
                    ),
                )
                torch_runtime_2 = base_types.Runtime(
                    name=constants.DEFAULT_TRAINING_RUNTIME,
                    trainer=base_types.RuntimeTrainer(
                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                        framework="torch",
                        num_nodes=2,
                        image="example.com/container",
                    ),
                )
                mock_github.side_effect = [[torch_runtime_1], [torch_runtime_2]]
                mock_defaults.return_value = []

            elif test_case.name == "fallback to defaults":
                # GitHub yields nothing, so the defaults must be returned.
                mock_github.return_value = []
                default_runtime = base_types.Runtime(
                    name=constants.DEFAULT_TRAINING_RUNTIME,
                    trainer=base_types.RuntimeTrainer(
                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                        framework="torch",
                        num_nodes=1,
                        image="example.com/container",
                    ),
                )
                mock_defaults.return_value = [default_runtime]

            runtimes = runtime_loader.list_training_runtimes_from_sources(
                test_case.config["sources"]
            )

            # Verify the aggregate size and that every expected name is present.
            assert len(runtimes) == test_case.config["expected_count"]
            runtime_names = [r.name for r in runtimes]
            for expected_name in test_case.config["expected_names"]:
                assert expected_name in runtime_names

            assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
350
+
351
+
352
def test_create_default_runtimes():
    """Verify the built-in default runtimes derived from constants."""
    print("Executing test: create default runtimes")
    defaults = runtime_loader._create_default_runtimes()

    # One default runtime per entry in DEFAULT_FRAMEWORK_IMAGES.
    assert len(defaults) == len(constants.DEFAULT_FRAMEWORK_IMAGES)

    # Exactly one torch runtime should exist; inspect it in detail.
    matches = [rt for rt in defaults if rt.trainer.framework == "torch"]
    assert len(matches) == 1
    torch_rt = matches[0]
    assert torch_rt.name == constants.DEFAULT_TRAINING_RUNTIME
    assert torch_rt.trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
    assert torch_rt.trainer.num_nodes == 1
    # The default image for the framework must be wired through.
    assert torch_rt.trainer.image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
    print("test execution complete")
368
+
369
+
370
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="discover runtime files",
            expected_status=SUCCESS,
            config={
                "html_content": """
                <html>
                <a>torch_distributed.yaml</a>
                <a>deepspeed_distributed.yaml</a>
                <a>kustomization.yaml</a>
                </html>
                """,
                "expected_files": ["torch_distributed.yaml", "deepspeed_distributed.yaml"],
                "excluded_files": ["kustomization.yaml"],
            },
        ),
        TestCase(
            name="discover runtime files custom repo",
            expected_status=SUCCESS,
            config={
                "html_content": """
                <html>
                <a>custom_runtime.yaml</a>
                </html>
                """,
                "expected_files": ["custom_runtime.yaml"],
                "owner": "myorg",
                "repo": "myrepo",
                "path": "custom/path",
            },
        ),
        TestCase(
            name="discover runtime files network error",
            expected_status=SUCCESS,
            config={
                "network_error": True,
                "expected_files": [],
            },
        ),
    ],
)
def test_discover_github_runtime_files(test_case):
    """Test discovering runtime files from GitHub.

    _discover_github_runtime_files scrapes a GitHub directory listing for
    ``*.yaml`` runtime manifests; the test also checks that network failures
    degrade to an empty result rather than raising.
    """
    print("Executing test:", test_case.name)
    try:
        with patch("urllib.request.urlopen") as mock_urlopen:
            if test_case.config.get("network_error"):
                # Simulate an unreachable GitHub: the loader must swallow
                # the error and report no files.
                mock_urlopen.side_effect = Exception("Network error")
            else:
                # urlopen is used as a context manager, so __enter__ must
                # hand back the mocked response object itself.
                mock_response = MagicMock()
                mock_response.read.return_value = test_case.config["html_content"].encode("utf-8")
                mock_response.__enter__.return_value = mock_response
                mock_urlopen.return_value = mock_response

            # Only pass owner/repo/path when the case overrides the defaults.
            kwargs = {}
            if "owner" in test_case.config:
                kwargs["owner"] = test_case.config["owner"]
                kwargs["repo"] = test_case.config["repo"]
                kwargs["path"] = test_case.config["path"]

            files = runtime_loader._discover_github_runtime_files(**kwargs)

            for expected_file in test_case.config["expected_files"]:
                assert expected_file in files

            # Non-runtime YAML (e.g. kustomization.yaml) must be filtered out.
            for excluded_file in test_case.config.get("excluded_files", []):
                assert excluded_file not in files

            if "owner" in test_case.config and not test_case.config.get("network_error"):
                # The requested URL should target the custom owner/repo/path.
                # NOTE(review): assumes urlopen was called with a plain URL
                # string as its first positional argument — confirm in loader.
                called_url = mock_urlopen.call_args[0][0]
                assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
                assert kwargs["path"] in called_url

            assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
450
+
451
+
452
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="fetch runtime success",
            expected_status=SUCCESS,
            config={
                "yaml_content": """
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: torch-distributed
""",
                # presumably DEFAULT_TRAINING_RUNTIME == "torch-distributed",
                # matching the manifest above — verify against constants.
                "expected_name": constants.DEFAULT_TRAINING_RUNTIME,
            },
        ),
        TestCase(
            name="fetch runtime custom repo",
            expected_status=SUCCESS,
            config={
                "yaml_content": """
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: custom-runtime
""",
                "expected_name": "custom-runtime",
                "runtime_file": "custom.yaml",
                "owner": "myorg",
                "repo": "myrepo",
                "path": "custom/path",
            },
        ),
        TestCase(
            name="fetch runtime network error",
            expected_status=SUCCESS,
            config={
                "network_error": True,
                "expected_none": True,
            },
        ),
    ],
)
def test_fetch_runtime_from_github(test_case):
    """Test fetching runtime YAML from GitHub.

    _fetch_runtime_from_github downloads a single runtime manifest from
    raw.githubusercontent.com and parses it; on network failure it should
    return None instead of raising.
    """
    print("Executing test:", test_case.name)
    try:
        with patch("urllib.request.urlopen") as mock_urlopen:
            if test_case.config.get("network_error"):
                # Any exception from urlopen must be absorbed by the loader.
                mock_urlopen.side_effect = Exception("Network error")
            else:
                # urlopen is used as a context manager, so __enter__ must
                # yield the mocked response itself.
                mock_response = MagicMock()
                mock_response.read.return_value = test_case.config["yaml_content"].encode("utf-8")
                mock_response.__enter__.return_value = mock_response
                mock_urlopen.return_value = mock_response

            # Fall back to the stock torch runtime file unless overridden.
            default_runtime_file = "torch_distributed.yaml"
            kwargs = {"runtime_file": test_case.config.get("runtime_file", default_runtime_file)}
            if "owner" in test_case.config:
                kwargs["owner"] = test_case.config["owner"]
                kwargs["repo"] = test_case.config["repo"]
                kwargs["path"] = test_case.config["path"]

            data = runtime_loader._fetch_runtime_from_github(**kwargs)

            if test_case.config.get("expected_none"):
                # Network errors degrade to None, not an exception.
                assert data is None
            else:
                assert data is not None
                assert data["metadata"]["name"] == test_case.config["expected_name"]

            if "owner" in test_case.config:
                # The raw-content URL must embed owner/repo and path/file.
                # NOTE(review): assumes urlopen was called with a plain URL
                # string as its first positional argument — confirm in loader.
                called_url = mock_urlopen.call_args[0][0]
                assert "raw.githubusercontent.com" in called_url
                assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
                assert f"{kwargs['path']}/{kwargs['runtime_file']}" in called_url

            assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
534
+
535
+
536
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="parse runtime yaml with custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "quay.io/custom/pytorch-arm:v1.0",
                "runtime_name": "torch-arm",
                "framework": "torch",
                "num_nodes": 2,
            },
        ),
        TestCase(
            name="parse runtime yaml with different custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "my-registry.io/pytorch:gpu-arm64",
                "runtime_name": "torch-gpu-arm",
                "framework": "torch",
                "num_nodes": 4,
            },
        ),
        TestCase(
            name="parse runtime yaml prefers container named node",
            expected_status=SUCCESS,
            config={
                "custom_image": "correct-node-image:v1.0",
                "runtime_name": "multi-container-runtime",
                "framework": "torch",
                "num_nodes": 1,
                "multiple_containers": True,
            },
        ),
    ],
)
def test_parse_runtime_yaml_extracts_image(test_case):
    """
    Ensure _parse_runtime_yaml pulls the container image out of the manifest
    and stores it on the resulting runtime. Guards against regressions where
    custom images are silently ignored.
    """
    print("Executing test:", test_case.name)
    cfg = test_case.config
    try:
        # Build the container list: either a lone trainer container, or a
        # sidecar + node pair to prove the 'node' container is preferred.
        if cfg.get("multiple_containers"):
            container_specs = [
                {
                    "name": "sidecar",
                    "image": "wrong-sidecar-image:v1.0",
                },
                {
                    "name": "node",
                    "image": cfg["custom_image"],
                },
            ]
        else:
            container_specs = [
                {
                    "name": "trainer",
                    "image": cfg["custom_image"],
                }
            ]

        # Minimal ClusterTrainingRuntime manifest wrapping the containers.
        pod_spec = {"template": {"spec": {"containers": container_specs}}}
        manifest = {
            "kind": "ClusterTrainingRuntime",
            "metadata": {
                "name": cfg["runtime_name"],
                "labels": {"trainer.kubeflow.org/framework": cfg["framework"]},
            },
            "spec": {
                "mlPolicy": {"numNodes": cfg["num_nodes"]},
                "template": {
                    "spec": {
                        "replicatedJobs": [
                            {
                                "name": "node",
                                "template": {"spec": pod_spec},
                            }
                        ]
                    }
                },
            },
        }

        parsed = runtime_loader._parse_runtime_yaml(manifest, "test")

        # Name, framework, node count and — crucially — the image must all
        # round-trip from the manifest onto the parsed runtime.
        assert parsed.name == cfg["runtime_name"]
        assert parsed.trainer.framework == cfg["framework"]
        assert parsed.trainer.num_nodes == cfg["num_nodes"]
        assert parsed.trainer.image == cfg["custom_image"]

        assert test_case.expected_status == SUCCESS

    except Exception as err:
        assert type(err) is test_case.expected_error
    print("test execution complete")