viettelcloud-aiplatform 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
@@ -0,0 +1,1344 @@
1
+ # Copyright 2025 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Unit tests for the KubernetesBackend class in the Kubeflow Trainer SDK.
17
+
18
+ This module uses pytest and unittest.mock to simulate Kubernetes API interactions.
19
+ It tests KubernetesBackend's behavior across job listing, resource creation etc
20
+ """
21
+
22
+ from dataclasses import asdict
23
+ import datetime
24
+ import multiprocessing
25
+ import random
26
+ import string
27
+ from typing import Optional
28
+ from unittest.mock import Mock, patch
29
+ import uuid
30
+
31
+ from kubeflow_trainer_api import models
32
+ import pytest
33
+
34
+ from viettelcloud.aiplatform.common.types import KubernetesBackendConfig
35
+ from viettelcloud.aiplatform.trainer.backends.kubernetes.backend import KubernetesBackend
36
+ import viettelcloud.aiplatform.trainer.backends.kubernetes.utils as utils
37
+ from viettelcloud.aiplatform.trainer.constants import constants
38
+ from viettelcloud.aiplatform.trainer.options import (
39
+ Annotations,
40
+ Labels,
41
+ SpecAnnotations,
42
+ SpecLabels,
43
+ )
44
+ from viettelcloud.aiplatform.trainer.test.common import (
45
+ DEFAULT_NAMESPACE,
46
+ FAILED,
47
+ RUNTIME,
48
+ SUCCESS,
49
+ TIMEOUT,
50
+ TestCase,
51
+ )
52
+ from viettelcloud.aiplatform.trainer.types import types
53
+
54
# In all tests runtime name is equal to the framework name.
TORCH_RUNTIME = "torch"
TORCH_TUNE_RUNTIME = "torchtune"

# Expected device count string for the default mocked runtime:
# 2 nodes * 2 nproc
RUNTIME_DEVICES = "4"

# Sentinel namespace that makes the mocked pod-log reader raise.
FAIL_LOGS = "fail_logs"
# Sentinel config name used by the list_runtimes test case.
LIST_RUNTIMES = "list_runtimes"
# Default TrainJob name shared by the mock responses below.
BASIC_TRAIN_JOB_NAME = "basic-job"
# Plural resource name used to dispatch TrainJob responses in the mocks.
TRAIN_JOBS = "trainjobs"
TRAIN_JOB_WITH_BUILT_IN_TRAINER = "train-job-with-built-in-trainer"
TRAIN_JOB_WITH_CUSTOM_TRAINER = "train-job-with-custom-trainer"
+
68
+
69
+ # --------------------------
70
+ # Fixtures
71
+ # --------------------------
72
+
73
+
74
@pytest.fixture
def kubernetes_backend(request):
    """Yield a KubernetesBackend whose Kubernetes API clients are fully mocked.

    All CustomObjectsApi / CoreV1Api calls are routed to the module-level
    mock handlers so no cluster access ever happens.
    """
    custom_objects_api = Mock(
        create_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        delete_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        get_namespaced_custom_object=Mock(
            side_effect=get_namespaced_custom_object_response
        ),
        get_cluster_custom_object=Mock(side_effect=get_cluster_custom_object_response),
        list_namespaced_custom_object=Mock(
            side_effect=list_namespaced_custom_object_response
        ),
        list_cluster_custom_object=Mock(side_effect=list_cluster_custom_object),
    )
    core_v1_api = Mock(
        list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
        read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log),
        list_namespaced_event=Mock(side_effect=mock_list_namespaced_event),
    )
    with (
        patch("kubernetes.config.load_kube_config", return_value=None),
        patch("kubernetes.client.CustomObjectsApi", return_value=custom_objects_api),
        patch("kubernetes.client.CoreV1Api", return_value=core_v1_api),
    ):
        yield KubernetesBackend(KubernetesBackendConfig())
105
+
106
+
107
+ # --------------------------
108
+ # Mock Handlers
109
+ # --------------------------
110
+
111
+
112
def conditional_error_handler(*args, **kwargs):
    """Simulate API failures by raising based on the resource-name argument."""
    resource_name = args[2]
    if resource_name == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if resource_name == RUNTIME:
        raise RuntimeError()
118
+
119
+
120
def list_namespaced_pod_response(*args, **kwargs):
    """Return an async-style mock whose .get() yields the mocked pod list."""
    async_result = Mock()
    async_result.get.return_value = get_mock_pod_list()
    return async_result
126
+
127
+
128
def get_mock_pod_list():
    """Create a mocked Kubernetes PodList object with pods for different training steps."""

    def build_pod(pod_name, rjob_label, container):
        # All three pods share the same jobset labels and a Running status;
        # only the name, replicated-job label, and container differ.
        return models.IoK8sApiCoreV1Pod(
            metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
                name=pod_name,
                namespace=DEFAULT_NAMESPACE,
                labels={
                    constants.JOBSET_NAME_LABEL: BASIC_TRAIN_JOB_NAME,
                    constants.JOBSET_RJOB_NAME_LABEL: rjob_label,
                    constants.JOB_INDEX_LABEL: "0",
                },
            ),
            spec=models.IoK8sApiCoreV1PodSpec(containers=[container]),
            status=models.IoK8sApiCoreV1PodStatus(phase="Running"),
        )

    dataset_container = models.IoK8sApiCoreV1Container(
        name=constants.DATASET_INITIALIZER,
        image="dataset-initializer:latest",
        command=["python", "-m", "dataset_initializer"],
    )
    model_container = models.IoK8sApiCoreV1Container(
        name=constants.MODEL_INITIALIZER,
        image="model-initializer:latest",
        command=["python", "-m", "model_initializer"],
    )
    # Only the training-node container declares resources.
    node_container = models.IoK8sApiCoreV1Container(
        name=constants.NODE,
        image="trainer:latest",
        command=["python", "-m", "trainer"],
        resources=get_resource_requirements(),
    )

    return models.IoK8sApiCoreV1PodList(
        items=[
            build_pod(
                "dataset-initializer-pod", constants.DATASET_INITIALIZER, dataset_container
            ),
            build_pod(
                "model-initializer-pod", constants.MODEL_INITIALIZER, model_container
            ),
            build_pod("node-0-pod", constants.NODE, node_container),
        ]
    )
201
+
202
+
203
def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements:
    """Build a mock ResourceRequirements (1 GPU; 2Gi requested / 4Gi limit)."""
    quantity = models.IoK8sApimachineryPkgApiResourceQuantity
    return models.IoK8sApiCoreV1ResourceRequirements(
        requests={"nvidia.com/gpu": quantity("1"), "memory": quantity("2Gi")},
        limits={"nvidia.com/gpu": quantity("1"), "memory": quantity("4Gi")},
    )
215
+
216
+
217
def get_custom_trainer(
    env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None,
    pip_index_urls: list[str] = constants.DEFAULT_PIP_INDEX_URLS,
    packages_to_install: Optional[list[str]] = None,
    image: Optional[str] = None,
) -> models.TrainerV1alpha1Trainer:
    """
    Get the custom trainer for the TrainJob.

    Args:
        env: Optional env vars for the trainer container.
        pip_index_urls: Index URLs; the first becomes --index-url, the rest
            --extra-index-url.
        packages_to_install: Packages to pip-install; defaults to
            ["torch", "numpy"].
        image: Optional container image override.
    """
    # Fix: the original used a mutable list literal as the default argument,
    # which is shared across calls. Bind the default per call instead.
    if packages_to_install is None:
        packages_to_install = ["torch", "numpy"]

    pip_command = [f"--index-url {pip_index_urls[0]}"]
    pip_command.extend([f"--extra-index-url {repo}" for repo in pip_index_urls[1:]])
    pip_command = " ".join(pip_command)

    packages_command = " ".join(packages_to_install)
    return models.TrainerV1alpha1Trainer(
        command=[
            "bash",
            "-c",
            '\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
            "|| python -m ensurepip --user || apt-get install python-pip"
            "\nfi\n\n"
            "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
            f" --no-warn-script-location {pip_command} --user {packages_command}"
            " ||\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
            f" --no-warn-script-location {pip_command} {packages_command}"
            "\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
            'print("Hello World"),\n\n<lambda>(**'
            "{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
            '"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
        ],
        numNodes=2,
        env=env,
        image=image,
    )
251
+
252
+
253
def get_custom_trainer_container(
    image: str,
    num_nodes: int,
    resources_per_node: models.IoK8sApiCoreV1ResourceRequirements,
    env: list[models.IoK8sApiCoreV1EnvVar],
) -> models.TrainerV1alpha1Trainer:
    """Build the container-based custom trainer spec for a TrainJob."""
    trainer_spec = {
        "image": image,
        "numNodes": num_nodes,
        "resourcesPerNode": resources_per_node,
        "env": env,
    }
    return models.TrainerV1alpha1Trainer(**trainer_spec)
269
+
270
+
271
def get_builtin_trainer(
    args: list[str],
) -> models.TrainerV1alpha1Trainer:
    """Build the builtin (torchtune) trainer spec: fixed "tune run" command, 2 nodes."""
    trainer = models.TrainerV1alpha1Trainer(
        command=["tune", "run"],
        numNodes=2,
        args=args,
    )
    return trainer
282
+
283
+
284
def get_train_job(
    runtime_name: str,
    train_job_name: str = BASIC_TRAIN_JOB_NAME,
    train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
    labels: Optional[dict[str, str]] = None,
    annotations: Optional[dict[str, str]] = None,
    spec_labels: Optional[dict[str, str]] = None,
    spec_annotations: Optional[dict[str, str]] = None,
) -> models.TrainerV1alpha1TrainJob:
    """
    Create a mock TrainJob object with optional trainer configurations.
    """
    # Metadata carries the object-level labels/annotations; the spec carries
    # its own, independent labels/annotations.
    metadata = models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
        name=train_job_name,
        labels=labels,
        annotations=annotations,
    )
    spec = models.TrainerV1alpha1TrainJobSpec(
        runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
        trainer=train_job_trainer,
        labels=spec_labels,
        annotations=spec_annotations,
    )
    return models.TrainerV1alpha1TrainJob(
        apiVersion=constants.API_VERSION,
        kind=constants.TRAINJOB_KIND,
        metadata=metadata,
        spec=spec,
    )
313
+
314
+
315
def get_cluster_custom_object_response(*args, **kwargs):
    """Return a mocked ClusterTrainingRuntime lookup (or raise on sentinel names)."""
    runtime_name = args[3]
    if runtime_name == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if runtime_name == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    if args[2] == constants.CLUSTER_TRAINING_RUNTIME_PLURAL:
        async_result.get.return_value = normalize_model(
            create_cluster_training_runtime(name=runtime_name),
            models.TrainerV1alpha1ClusterTrainingRuntime,
        )
    return async_result
329
+
330
+
331
def get_namespaced_custom_object_response(*args, **kwargs):
    """Return a mocked TrainJob lookup (or raise when namespace/name is a sentinel)."""
    namespace, plural, job_name = args[2], args[3], args[4]
    if TIMEOUT in (namespace, job_name):
        raise multiprocessing.TimeoutError()
    if RUNTIME in (namespace, job_name):
        raise RuntimeError()

    async_result = Mock()
    if plural == TRAIN_JOBS:  # TODO: review this.
        async_result.get.return_value = add_status(
            create_train_job(train_job_name=job_name)
        )
    return async_result
342
+
343
+
344
def add_status(
    train_job: models.TrainerV1alpha1TrainJob,
) -> models.TrainerV1alpha1TrainJob:
    """
    Attach a "Complete" status condition to the train job and return it.
    """
    completed = models.IoK8sApimachineryPkgApisMetaV1Condition(
        type="Complete",
        status="True",
        lastTransitionTime=datetime.datetime.now(),
        reason="JobCompleted",
        message="Job completed successfully",
    )
    train_job.status = models.TrainerV1alpha1TrainJobStatus(conditions=[completed])
    return train_job
364
+
365
+
366
def list_namespaced_custom_object_response(*args, **kwargs):
    """Return a list of mocked TrainJob objects (or raise on sentinel namespaces)."""
    namespace = args[2]
    if namespace == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if namespace == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    if args[3] == constants.TRAINJOB_PLURAL:
        jobs = [
            add_status(create_train_job(train_job_name=name))
            for name in ("basic-job-1", "basic-job-2")
        ]
        async_result.get.return_value = normalize_model(
            models.TrainerV1alpha1TrainJobList(items=jobs),
            models.TrainerV1alpha1TrainJobList,
        )
    return async_result
384
+
385
+
386
def list_cluster_custom_object(*args, **kwargs):
    """Return a generic mocked response for cluster object listing."""
    plural = args[2]
    if plural == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if plural == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    if plural == constants.CLUSTER_TRAINING_RUNTIME_PLURAL:
        runtimes = [
            create_cluster_training_runtime(name=f"runtime-{index}")
            for index in (1, 2)
        ]
        async_result.get.return_value = normalize_model(
            models.TrainerV1alpha1ClusterTrainingRuntimeList(items=runtimes),
            models.TrainerV1alpha1ClusterTrainingRuntimeList,
        )
    return async_result
404
+
405
+
406
def mock_read_namespaced_pod_log(*args, **kwargs):
    """Simulate pod-log retrieval; fail for the FAIL_LOGS sentinel namespace."""
    if kwargs.get("namespace") != FAIL_LOGS:
        return "test log content"
    raise Exception("Failed to read logs")
411
+
412
+
413
def mock_list_namespaced_event(*args, **kwargs):
    """Simulate event listing from namespace."""
    namespace = kwargs.get("namespace")

    # Errors occur at call time, not during .get()
    if namespace == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if namespace == RUNTIME:
        raise RuntimeError()

    def build_event(event_name, kind, object_name, message, reason, minute):
        # Both events live in the default namespace; only the involved object
        # and the timestamp's minute field differ.
        return models.IoK8sApiCoreV1Event(
            metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
                name=event_name,
                namespace=DEFAULT_NAMESPACE,
            ),
            involvedObject=models.IoK8sApiCoreV1ObjectReference(
                kind=kind,
                name=object_name,
                namespace=DEFAULT_NAMESPACE,
            ),
            message=message,
            reason=reason,
            firstTimestamp=datetime.datetime(2025, 6, 1, 10, minute, 0),
        )

    async_result = Mock()
    async_result.get.return_value = models.IoK8sApiCoreV1EventList(
        items=[
            build_event(
                "test-event-1",
                constants.TRAINJOB_KIND,
                BASIC_TRAIN_JOB_NAME,
                "TrainJob created successfully",
                "Created",
                30,
            ),
            build_event(
                "test-event-2",
                "Pod",
                "node-0-pod",
                "Pod scheduled successfully",
                "Scheduled",
                31,
            ),
        ]
    )
    return async_result
457
+
458
+
459
def mock_watch(*args, **kwargs):
    """Simulate a pod watch stream; raise when timeout_seconds == 1."""
    if kwargs.get("timeout_seconds") == 1:
        raise TimeoutError("Watch timeout")

    node_pod = {
        "metadata": {
            "name": f"{BASIC_TRAIN_JOB_NAME}-node-0",
            "labels": {
                constants.JOBSET_NAME_LABEL: BASIC_TRAIN_JOB_NAME,
                constants.JOBSET_RJOB_NAME_LABEL: constants.NODE,
                constants.JOB_INDEX_LABEL: "0",
            },
        },
        "spec": {"containers": [{"name": constants.NODE}]},
        "status": {"phase": "Running"},
    }
    # A single MODIFIED event for the running node pod.
    return iter([{"type": "MODIFIED", "object": node_pod}])
483
+
484
+
485
def normalize_model(model_obj, model_class):
    """Round-trip a model through its dict form to mimic a real API response.

    Serializing to a raw dictionary and re-parsing it exercises the same
    instantiation and type validation a genuine API response would trigger.
    """
    raw_response = model_obj.to_dict()
    return model_class.from_dict(raw_response)
490
+
491
+
492
+ # --------------------------
493
+ # Object Creators
494
+ # --------------------------
495
+
496
+
497
def create_train_job(
    train_job_name: Optional[str] = None,
    namespace: str = "default",
    image: str = "pytorch/pytorch:latest",
    initializer: Optional[types.Initializer] = None,
    command: Optional[list] = None,
    args: Optional[list] = None,
) -> models.TrainerV1alpha1TrainJob:
    """Create a mock TrainJob object.

    Args:
        train_job_name: Job name; when None, a fresh random name (one
            lowercase letter + 11 hex chars) is generated per call.
        namespace: Target namespace for the job metadata.
        image: Unused here; kept for signature compatibility with callers.
        initializer: Optional dataset/model initializer config.
        command: Unused here; kept for signature compatibility with callers.
        args: Unused here; kept for signature compatibility with callers.
    """
    # Fix: the original default expression ran once at import time, so every
    # defaulted call shared the same "random" name. Generate one per call.
    if train_job_name is None:
        train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]

    return models.TrainerV1alpha1TrainJob(
        apiVersion=constants.API_VERSION,
        kind=constants.TRAINJOB_KIND,
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            name=train_job_name,
            namespace=namespace,
            creationTimestamp=datetime.datetime(2025, 6, 1, 10, 30, 0),
        ),
        spec=models.TrainerV1alpha1TrainJobSpec(
            runtimeRef=models.TrainerV1alpha1RuntimeRef(name=TORCH_RUNTIME),
            trainer=None,
            initializer=(
                models.TrainerV1alpha1Initializer(
                    dataset=utils.get_dataset_initializer(initializer.dataset),
                    model=utils.get_model_initializer(initializer.model),
                )
                if initializer
                else None
            ),
        ),
    )
527
+
528
+
529
def create_cluster_training_runtime(
    name: str,
    namespace: str = "default",
) -> models.TrainerV1alpha1ClusterTrainingRuntime:
    """Create a mock ClusterTrainingRuntime object."""
    metadata = models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
        name=name,
        namespace=namespace,
        # In tests the framework label always matches the runtime name.
        labels={constants.RUNTIME_FRAMEWORK_LABEL: name},
    )
    ml_policy = models.TrainerV1alpha1MLPolicy(
        torch=models.TrainerV1alpha1TorchMLPolicySource(
            numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2)
        ),
        numNodes=2,
    )
    jobset_template = models.TrainerV1alpha1JobSetTemplateSpec(
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            name=name,
            namespace=namespace,
        ),
        spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]),
    )
    return models.TrainerV1alpha1ClusterTrainingRuntime(
        apiVersion=constants.API_VERSION,
        kind="ClusterTrainingRuntime",
        metadata=metadata,
        spec=models.TrainerV1alpha1TrainingRuntimeSpec(
            mlPolicy=ml_policy,
            template=jobset_template,
        ),
    )
559
+
560
+
561
def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob:
    """Build the single "node" replicated job used by the mocked runtime spec."""
    pod_template = models.IoK8sApiCoreV1PodTemplateSpec(
        spec=models.IoK8sApiCoreV1PodSpec(containers=[get_container()])
    )
    job_template = models.IoK8sApiBatchV1JobTemplateSpec(
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            labels={"trainer.kubeflow.org/trainjob-ancestor-step": "trainer"}
        ),
        spec=models.IoK8sApiBatchV1JobSpec(template=pod_template),
    )
    return models.JobsetV1alpha2ReplicatedJob(
        name="node",
        replicas=1,
        template=job_template,
    )
576
+
577
+
578
def get_container() -> models.IoK8sApiCoreV1Container:
    """Build the mocked trainer "node" container with its resource spec."""
    node_container = models.IoK8sApiCoreV1Container(
        name="node",
        image="example.com/test-runtime",
        command=["echo", "Hello World"],
        resources=get_resource_requirements(),
    )
    return node_container
585
+
586
+
587
def create_runtime_type(
    name: str,
) -> types.Runtime:
    """Create a mock Runtime object for testing."""
    runtime_trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=name,
        num_nodes=2,
        device="gpu",
        device_count=RUNTIME_DEVICES,
        image="example.com/test-runtime",
    )
    # The expected command comes from the shared torch constant.
    runtime_trainer.set_command(constants.TORCH_COMMAND)
    return types.Runtime(name=name, trainer=runtime_trainer)
604
+
605
+
606
def get_train_job_data_type(
    runtime_name: str,
    train_job_name: str,
) -> types.TrainJob:
    """Create a mock TrainJob object with the expected structure for testing."""

    def build_step(step_name, pod_name, device, device_count):
        # Every expected step reports a Running status; device info varies.
        return types.Step(
            name=step_name,
            status="Running",
            pod_name=pod_name,
            device=device,
            device_count=device_count,
        )

    runtime_trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=runtime_name,
        device="gpu",
        device_count=RUNTIME_DEVICES,
        num_nodes=2,
        image="example.com/test-runtime",
    )
    runtime_trainer.set_command(constants.TORCH_COMMAND)

    expected_steps = [
        build_step("dataset-initializer", "dataset-initializer-pod", "Unknown", "Unknown"),
        build_step("model-initializer", "model-initializer-pod", "Unknown", "Unknown"),
        build_step("node-0", "node-0-pod", "gpu", "1"),
    ]
    return types.TrainJob(
        name=train_job_name,
        creation_timestamp=datetime.datetime(2025, 6, 1, 10, 30, 0),
        runtime=types.Runtime(name=runtime_name, trainer=runtime_trainer),
        steps=expected_steps,
        num_nodes=2,
        status="Complete",
    )
654
+
655
+
656
+ # --------------------------
657
+ # Tests
658
+ # --------------------------
659
+
660
+
661
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": TORCH_RUNTIME},
            expected_output=create_runtime_type(name=TORCH_RUNTIME),
        ),
        TestCase(
            name="timeout error when getting runtime",
            expected_status=FAILED,
            config={"name": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting runtime",
            expected_status=FAILED,
            config={"name": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_runtime(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_runtime with basic success path.

    Success cases must return a Runtime equal (field-by-field, via asdict)
    to the expected output; sentinel runtime names make the mocked API raise.
    """
    print("Executing test:", test_case.name)
    try:
        runtime = kubernetes_backend.get_runtime(**test_case.config)

        # No exception was raised, which is only valid for SUCCESS cases.
        # NOTE(review): the asserts sit inside the try block, so an
        # AssertionError here falls into the except handler below.
        assert test_case.expected_status == SUCCESS
        assert isinstance(runtime, types.Runtime)
        # Compare dataclasses structurally rather than by identity.
        assert asdict(runtime) == asdict(test_case.expected_output)

    except Exception as e:
        # Failure cases must raise exactly the expected exception type.
        assert type(e) is test_case.expected_error
    print("test execution complete")
697
+
698
+
699
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": LIST_RUNTIMES},
            expected_output=[
                create_runtime_type(name="runtime-1"),
                create_runtime_type(name="runtime-2"),
            ],
        ),
    ],
)
def test_list_runtimes(kubernetes_backend, test_case):
    """Test KubernetesBackend.list_runtimes with basic success path.

    The mocked cluster API returns two ClusterTrainingRuntimes, so the
    backend must surface two structurally-equal Runtime objects.
    """
    print("Executing test:", test_case.name)
    try:
        # Override the backend's namespace from the test config, falling back
        # to the shared default namespace.
        kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
        runtimes = kubernetes_backend.list_runtimes()

        # No exception was raised, which is only valid for SUCCESS cases.
        assert test_case.expected_status == SUCCESS
        assert isinstance(runtimes, list)
        assert all(isinstance(r, types.Runtime) for r in runtimes)
        # Compare dataclasses structurally, element by element.
        assert [asdict(r) for r in runtimes] == [asdict(r) for r in test_case.expected_output]

    except Exception as e:
        # Failure cases must raise exactly the expected exception type.
        assert type(e) is test_case.expected_error
    print("test execution complete")
728
+
729
+
730
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with custom trainer runtime",
            expected_status=SUCCESS,
            config={"runtime": create_runtime_type(name=TORCH_RUNTIME)},
        ),
        TestCase(
            name="value error with builtin trainer runtime",
            expected_status=FAILED,
            config={
                "runtime": types.Runtime(
                    name="torchtune-runtime",
                    trainer=types.RuntimeTrainer(
                        trainer_type=types.TrainerType.BUILTIN_TRAINER,
                        framework="torchtune",
                        num_nodes=1,
                        device="cpu",
                        device_count="1",
                        image="example.com/image",
                    ),
                )
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_runtime_packages(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_runtime_packages for success and failure paths.

    Custom-trainer runtimes must succeed; builtin-trainer runtimes must raise
    ValueError.
    """
    print("Executing test:", test_case.name)

    try:
        kubernetes_backend.get_runtime_packages(**test_case.config)
        # Bug fix: previously a FAILED case that raised nothing passed
        # silently. Reaching this line without an exception is only valid
        # for SUCCESS cases.
        assert test_case.expected_status == SUCCESS
    except Exception as e:
        # Failure cases must raise exactly the expected exception type.
        assert type(e) is test_case.expected_error

    print("test execution complete")
768
+
769
+
770
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={},
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="valid flow with built in trainer",
            expected_status=SUCCESS,
            config={
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        num_nodes=2,
                        batch_size=2,
                        epochs=2,
                        loss=types.Loss.CEWithChunkedOutputLoss,
                    )
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            # TorchTune config fields are expected to be flattened into CLI args.
            expected_output=get_train_job(
                runtime_name=TORCH_TUNE_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
                train_job_trainer=get_builtin_trainer(
                    args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"],
                ),
            ),
        ),
        TestCase(
            name="valid flow with built in trainer and lora config",
            expected_status=SUCCESS,
            config={
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        num_nodes=2,
                        peft_config=types.LoraConfig(
                            apply_lora_to_mlp=True,
                            lora_rank=8,
                            lora_alpha=16,
                            lora_dropout=0.1,
                        ),
                    ),
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            # LoRA fields are expected to be prefixed with "model." in the args.
            expected_output=get_train_job(
                runtime_name=TORCH_TUNE_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
                train_job_trainer=get_builtin_trainer(
                    args=[
                        "model.apply_lora_to_mlp=True",
                        "model.lora_rank=8",
                        "model.lora_alpha=16",
                        "model.lora_dropout=0.1",
                        "model.lora_attn_modules=[q_proj,v_proj,output_proj]",
                    ],
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    func_args={"learning_rate": 0.001, "batch_size": 32},
                    packages_to_install=["torch", "numpy"],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    num_nodes=2,
                )
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer(
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    packages_to_install=["torch", "numpy"],
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer that has env and image",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    func_args={"learning_rate": 0.001, "batch_size": 32},
                    packages_to_install=["torch", "numpy"],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    num_nodes=2,
                    env={
                        "TEST_ENV": "test_value",
                        "ANOTHER_ENV": "another_value",
                    },
                    image="my-custom-image",
                )
            },
            # The env dict is expected to be converted to k8s EnvVar objects.
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer(
                    env=[
                        models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
                        models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
                    ],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    packages_to_install=["torch", "numpy"],
                    image="my-custom-image",
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer container",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainerContainer(
                    image="example.com/my-image",
                    num_nodes=2,
                    resources_per_node={"cpu": 5, "gpu": 3},
                    env={
                        "TEST_ENV": "test_value",
                        "ANOTHER_ENV": "another_value",
                    },
                )
            },
            # "gpu" is expected to be mapped to the "nvidia.com/gpu" resource
            # key, mirrored into both requests and limits.
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer_container(
                    image="example.com/my-image",
                    num_nodes=2,
                    resources_per_node=models.IoK8sApiCoreV1ResourceRequirements(
                        requests={
                            "cpu": models.IoK8sApimachineryPkgApiResourceQuantity(5),
                            "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(3),
                        },
                        limits={
                            "cpu": models.IoK8sApimachineryPkgApiResourceQuantity(5),
                            "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(3),
                        },
                    ),
                    env=[
                        models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
                        models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
                    ],
                ),
            ),
        ),
        # NOTE(review): the next two case names say "deleting job" but this is
        # the train test — the names look copy-pasted from the delete tests;
        # consider renaming to "... when creating job".
        TestCase(
            name="timeout error when deleting job",
            expected_status=FAILED,
            config={
                "namespace": TIMEOUT,
            },
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when deleting job",
            expected_status=FAILED,
            config={
                "namespace": RUNTIME,
            },
            expected_error=RuntimeError,
        ),
        TestCase(
            name="value error when runtime doesn't support CustomTrainer",
            expected_status=FAILED,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    num_nodes=2,
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="train with metadata labels and annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    Labels({"team": "ml-platform"}),
                    Annotations({"created-by": "sdk"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                labels={"team": "ml-platform"},
                annotations={"created-by": "sdk"},
            ),
        ),
        TestCase(
            name="train with spec labels and annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    SpecLabels({"app": "training", "version": "v1.0"}),
                    SpecAnnotations({"prometheus.io/scrape": "true"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                spec_labels={"app": "training", "version": "v1.0"},
                spec_annotations={"prometheus.io/scrape": "true"},
            ),
        ),
        TestCase(
            name="train with both metadata and spec labels/annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    Labels({"owner": "ml-team"}),
                    Annotations({"description": "Fine-tuning job"}),
                    SpecLabels({"app": "training", "version": "v1.0"}),
                    SpecAnnotations({"prometheus.io/scrape": "true"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                labels={"owner": "ml-team"},
                annotations={"description": "Fine-tuning job"},
                spec_labels={"app": "training", "version": "v1.0"},
                spec_annotations={"prometheus.io/scrape": "true"},
            ),
        ),
    ],
)
def test_train(kubernetes_backend, test_case):
    """Test KubernetesBackend.train with basic success path.

    Each case drives ``train`` with a different trainer/runtime/options
    combination from ``test_case.config`` and asserts that the mocked
    ``create_namespaced_custom_object`` was invoked with the expected
    TrainJob payload. Error cases select a namespace/runtime that makes
    the backend raise, and only the exception type is asserted.
    """
    print("Executing test:", test_case.name)
    try:
        # Error cases are triggered by a sentinel namespace (TIMEOUT/RUNTIME).
        kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
        runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME))

        options = test_case.config.get("options", [])

        train_job_name = kubernetes_backend.train(
            runtime=runtime,
            trainer=test_case.config.get("trainer", None),
            options=options,
        )

        assert test_case.expected_status == SUCCESS

        # This is to get around the fact that the train job name is dynamically generated
        # In the future name generation may be more deterministic, and we can revisit this approach
        expected_output = test_case.expected_output
        expected_output.metadata.name = train_job_name

        # Verify the exact custom-object create call made against the mocked API.
        kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with(
            constants.GROUP,
            constants.VERSION,
            DEFAULT_NAMESPACE,
            constants.TRAINJOB_PLURAL,
            expected_output.to_dict(),
        )

    except Exception as e:
        # For FAILED cases the raised type must match exactly; for SUCCESS
        # cases expected_error is unset, so any exception fails this assert.
        assert type(e) is test_case.expected_error
    print("test execution complete")
1039
+
1040
+
1041
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="timeout error when getting job",
            expected_status=FAILED,
            config={"name": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting job",
            expected_status=FAILED,
            config={"name": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_job with basic success path.

    Success cases fetch a job by name and compare its dataclass fields
    against the expected TrainJob; error cases use a sentinel job name
    (TIMEOUT/RUNTIME) and assert only the raised exception type.
    """
    print("Executing test:", test_case.name)
    try:
        job = kubernetes_backend.get_job(**test_case.config)

        assert test_case.expected_status == SUCCESS
        # Compare via asdict so nested dataclasses are checked field by field.
        assert asdict(job) == asdict(test_case.expected_output)

    except Exception as e:
        # Re-raise when no error was expected: otherwise an AssertionError
        # from the success-path asserts would be swallowed here and reported
        # as a confusing "type(e) is None" mismatch instead of the real failure.
        if test_case.expected_error is None:
            raise
        assert type(e) is test_case.expected_error
    print("test execution complete")
1079
+
1080
+
1081
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={},
            expected_output=[
                get_train_job_data_type(
                    runtime_name=TORCH_RUNTIME,
                    train_job_name="basic-job-1",
                ),
                get_train_job_data_type(
                    runtime_name=TORCH_RUNTIME,
                    train_job_name="basic-job-2",
                ),
            ],
        ),
        TestCase(
            name="timeout error when listing jobs",
            expected_status=FAILED,
            config={"namespace": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when listing jobs",
            expected_status=FAILED,
            config={"namespace": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_list_jobs(kubernetes_backend, test_case):
    """Verify KubernetesBackend.list_jobs for success and error scenarios.

    The success case expects exactly the two mocked jobs; error cases use a
    sentinel namespace and check only the exception type raised.
    """
    print("Executing test:", test_case.name)
    target_namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    try:
        kubernetes_backend.namespace = target_namespace
        listed = kubernetes_backend.list_jobs()

        assert test_case.expected_status == SUCCESS
        assert isinstance(listed, list)
        assert len(listed) == 2
        # Compare dataclass contents rather than object identity.
        actual = [asdict(job) for job in listed]
        expected = [asdict(job) for job in test_case.expected_output]
        assert actual == expected

    except Exception as err:
        assert type(err) is test_case.expected_error
    print("test execution complete")
1128
+
1129
+
1130
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=["test log content"],
        ),
        TestCase(
            name="runtime error when getting logs",
            expected_status=FAILED,
            config={"name": BASIC_TRAIN_JOB_NAME, "namespace": FAIL_LOGS},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job_logs(kubernetes_backend, test_case):
    """Verify KubernetesBackend.get_job_logs for success and failure paths.

    The failure case switches to the FAIL_LOGS sentinel namespace, which
    makes the mocked log read raise a RuntimeError.
    """
    print("Executing test:", test_case.name)
    try:
        kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
        log_stream = kubernetes_backend.get_job_logs(test_case.config.get("name"))
        # Materialize the (possibly lazy) log iterator before comparing.
        collected = [line for line in log_stream]
        assert test_case.expected_status == SUCCESS
        assert collected == test_case.expected_output
    except Exception as err:
        assert type(err) is test_case.expected_error
    print("test execution complete")
1160
+
1161
+
1162
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="wait for complete status (default)",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="wait for multiple statuses",
            expected_status=SUCCESS,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {constants.TRAINJOB_RUNNING, constants.TRAINJOB_COMPLETE},
            },
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="invalid status set error",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {"InvalidStatus"},
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="polling interval is more than timeout error",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "timeout": 1,
                "polling_interval": 2,
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="job failed when not expected",
            expected_status=FAILED,
            config={
                "name": "failed-job",
                "status": {constants.TRAINJOB_RUNNING},
            },
            expected_error=RuntimeError,
        ),
        TestCase(
            name="timeout error to wait for failed status",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {constants.TRAINJOB_FAILED},
                "polling_interval": 1,
                "timeout": 2,
            },
            expected_error=TimeoutError,
        ),
    ],
)
def test_wait_for_job_status(kubernetes_backend, test_case):
    """Test KubernetesBackend.wait_for_job_status with various scenarios.

    Covers the success paths (waiting for the default Complete status and
    for a set of statuses), validation errors (unknown status name,
    polling_interval > timeout), an unexpected Failed job, and a timeout
    while waiting for a status that never arrives.
    """
    print("Executing test:", test_case.name)

    # Keep a reference to the real method so the wrapper below can delegate.
    original_get_job = kubernetes_backend.get_job

    # TrainJob has unexpected failed status.
    # Wrapper forces a Failed status only for the "failed-job" case so that
    # wait_for_job_status sees an unexpected failure and raises RuntimeError.
    def mock_get_job(name):
        job = original_get_job(name)
        if test_case.config.get("name") == "failed-job":
            job.status = constants.TRAINJOB_FAILED
        return job

    # NOTE(review): the patched method is not restored afterwards; this relies
    # on the kubernetes_backend fixture being recreated per test — confirm.
    kubernetes_backend.get_job = mock_get_job

    try:
        job = kubernetes_backend.wait_for_job_status(**test_case.config)

        assert test_case.expected_status == SUCCESS
        assert isinstance(job, types.TrainJob)
        # Job status should be in the expected set.
        assert job.status in test_case.config.get("status", {constants.TRAINJOB_COMPLETE})

    except Exception as e:
        assert type(e) is test_case.expected_error

    print("test execution complete")
1254
+
1255
+
1256
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=None,
        ),
        TestCase(
            name="timeout error when deleting job",
            expected_status=FAILED,
            config={"namespace": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when deleting job",
            expected_status=FAILED,
            config={"namespace": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_delete_job(kubernetes_backend, test_case):
    """Test KubernetesBackend.delete_job with basic success path.

    delete_job returns nothing on success, so the success case only checks
    that no exception was raised; error cases use a sentinel namespace
    (TIMEOUT/RUNTIME) and assert the exception type.
    """
    print("Executing test:", test_case.name)
    try:
        kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
        kubernetes_backend.delete_job(test_case.config.get("name"))
        assert test_case.expected_status == SUCCESS

    except Exception as e:
        # Re-raise when no error was expected: otherwise an AssertionError
        # from the success-path assert would be swallowed here and reported
        # as a confusing "type(e) is None" mismatch instead of the real failure.
        if test_case.expected_error is None:
            raise
        assert type(e) is test_case.expected_error
    print("test execution complete")
1290
+
1291
+
1292
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="get job events with valid trainjob",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=[
                types.Event(
                    involved_object_kind=constants.TRAINJOB_KIND,
                    involved_object_name=BASIC_TRAIN_JOB_NAME,
                    message="TrainJob created successfully",
                    reason="Created",
                    event_time=datetime.datetime(2025, 6, 1, 10, 30, 0),
                ),
                types.Event(
                    involved_object_kind="Pod",
                    involved_object_name="node-0-pod",
                    message="Pod scheduled successfully",
                    reason="Scheduled",
                    event_time=datetime.datetime(2025, 6, 1, 10, 31, 0),
                ),
            ],
        ),
        TestCase(
            name="timeout error when getting job events",
            expected_status=FAILED,
            config={"namespace": TIMEOUT, "name": BASIC_TRAIN_JOB_NAME},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting job events",
            expected_status=FAILED,
            config={"namespace": RUNTIME, "name": BASIC_TRAIN_JOB_NAME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job_events(kubernetes_backend, test_case):
    """Verify KubernetesBackend.get_job_events across success and error scenarios.

    The success case expects both the TrainJob-level and Pod-level events;
    error cases use a sentinel namespace and assert only the exception type.
    """
    print("Executing test:", test_case.name)
    try:
        kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
        observed = kubernetes_backend.get_job_events(test_case.config.get("name"))

        assert test_case.expected_status == SUCCESS
        assert isinstance(observed, list)
        expected_events = test_case.expected_output
        assert len(observed) == len(expected_events)
        # Field-by-field comparison via asdict covers nested dataclasses.
        assert [asdict(ev) for ev in observed] == [asdict(ev) for ev in expected_events]

    except Exception as err:
        assert type(err) is test_case.expected_error
    print("test execution complete")