viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,1344 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Unit tests for the KubernetesBackend class in the Kubeflow Trainer SDK.
|
|
17
|
+
|
|
18
|
+
This module uses pytest and unittest.mock to simulate Kubernetes API interactions.
|
|
19
|
+
It tests KubernetesBackend's behavior across job listing, resource creation, etc.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from dataclasses import asdict
|
|
23
|
+
import datetime
|
|
24
|
+
import multiprocessing
|
|
25
|
+
import random
|
|
26
|
+
import string
|
|
27
|
+
from typing import Optional
|
|
28
|
+
from unittest.mock import Mock, patch
|
|
29
|
+
import uuid
|
|
30
|
+
|
|
31
|
+
from kubeflow_trainer_api import models
|
|
32
|
+
import pytest
|
|
33
|
+
|
|
34
|
+
from viettelcloud.aiplatform.common.types import KubernetesBackendConfig
|
|
35
|
+
from viettelcloud.aiplatform.trainer.backends.kubernetes.backend import KubernetesBackend
|
|
36
|
+
import viettelcloud.aiplatform.trainer.backends.kubernetes.utils as utils
|
|
37
|
+
from viettelcloud.aiplatform.trainer.constants import constants
|
|
38
|
+
from viettelcloud.aiplatform.trainer.options import (
|
|
39
|
+
Annotations,
|
|
40
|
+
Labels,
|
|
41
|
+
SpecAnnotations,
|
|
42
|
+
SpecLabels,
|
|
43
|
+
)
|
|
44
|
+
from viettelcloud.aiplatform.trainer.test.common import (
|
|
45
|
+
DEFAULT_NAMESPACE,
|
|
46
|
+
FAILED,
|
|
47
|
+
RUNTIME,
|
|
48
|
+
SUCCESS,
|
|
49
|
+
TIMEOUT,
|
|
50
|
+
TestCase,
|
|
51
|
+
)
|
|
52
|
+
from viettelcloud.aiplatform.trainer.types import types
|
|
53
|
+
|
|
54
|
+
# In all tests runtime name is equal to the framework name.
TORCH_RUNTIME = "torch"
# NOTE(review): presumably the runtime/framework name for TorchTune-based tests;
# it is not referenced in the visible part of this module - confirm.
TORCH_TUNE_RUNTIME = "torchtune"

# 2 nodes * 2 nproc
# Used as the expected `device_count` of the runtime trainer in the helpers below.
RUNTIME_DEVICES = "4"

# Namespace value that makes `mock_read_namespaced_pod_log` raise.
FAIL_LOGS = "fail_logs"
# Placeholder `name` used in the `test_list_runtimes` test-case config.
LIST_RUNTIMES = "list_runtimes"
# Job name used in the mocked pod labels, mocked events and watch events.
BASIC_TRAIN_JOB_NAME = "basic-job"
# Value compared against the 4th positional argument of the mocked
# `get_namespaced_custom_object` call to decide whether a TrainJob is returned.
TRAIN_JOBS = "trainjobs"
# NOTE(review): presumably TrainJob names for the built-in / custom trainer test
# cases; they are not referenced in the visible part of this module - confirm.
TRAIN_JOB_WITH_BUILT_IN_TRAINER = "train-job-with-built-in-trainer"
TRAIN_JOB_WITH_CUSTOM_TRAINER = "train-job-with-custom-trainer"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# --------------------------
|
|
70
|
+
# Fixtures
|
|
71
|
+
# --------------------------
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@pytest.fixture
def kubernetes_backend(request):
    """Yield a KubernetesBackend whose Kubernetes client classes are replaced by mocks.

    While the fixture is active:
      * ``kubernetes.config.load_kube_config`` is patched to return ``None``.
      * ``kubernetes.client.CustomObjectsApi`` returns a mock whose methods are
        driven by the handler functions defined in this module.
      * ``kubernetes.client.CoreV1Api`` returns a mock for pod listing, pod log
        reading and event listing.

    The ``request`` argument is accepted but not used.
    """
    # Stand-in for the custom objects API client.
    custom_objects_api = Mock(
        create_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        delete_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
        get_namespaced_custom_object=Mock(side_effect=get_namespaced_custom_object_response),
        get_cluster_custom_object=Mock(side_effect=get_cluster_custom_object_response),
        list_namespaced_custom_object=Mock(side_effect=list_namespaced_custom_object_response),
        list_cluster_custom_object=Mock(side_effect=list_cluster_custom_object),
    )

    # Stand-in for the core v1 API client.
    core_v1_api = Mock(
        list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
        read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log),
        list_namespaced_event=Mock(side_effect=mock_list_namespaced_event),
    )

    with (
        patch("kubernetes.config.load_kube_config", return_value=None),
        patch("kubernetes.client.CustomObjectsApi", return_value=custom_objects_api),
        patch("kubernetes.client.CoreV1Api", return_value=core_v1_api),
    ):
        yield KubernetesBackend(KubernetesBackendConfig())
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# --------------------------
|
|
108
|
+
# Mock Handlers
|
|
109
|
+
# --------------------------
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def conditional_error_handler(*args, **kwargs):
    """Raise a simulated error chosen by the third positional argument.

    Raises:
        multiprocessing.TimeoutError: if ``args[2]`` equals ``TIMEOUT``.
        RuntimeError: if ``args[2]`` equals ``RUNTIME``.

    Returns ``None`` for any other value.
    """
    selector = args[2]
    if selector == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if selector == RUNTIME:
        raise RuntimeError()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def list_namespaced_pod_response(*args, **kwargs):
    """Return a mock whose ``get()`` yields the mocked pod list."""
    async_result = Mock()
    async_result.get.return_value = get_mock_pod_list()
    return async_result
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_mock_pod_list():
    """Build a PodList with one running pod per training step.

    The list contains, in order, a dataset initializer pod, a model initializer
    pod and a training node pod. Only the training node container carries
    resource requirements.
    """

    def _running_pod(pod_name, step, image, module, resources=None):
        # `resources` is only forwarded when given, so the initializer
        # containers are constructed without that keyword at all.
        container_fields = {
            "name": step,
            "image": image,
            "command": ["python", "-m", module],
        }
        if resources is not None:
            container_fields["resources"] = resources

        return models.IoK8sApiCoreV1Pod(
            metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
                name=pod_name,
                namespace=DEFAULT_NAMESPACE,
                labels={
                    constants.JOBSET_NAME_LABEL: BASIC_TRAIN_JOB_NAME,
                    constants.JOBSET_RJOB_NAME_LABEL: step,
                    constants.JOB_INDEX_LABEL: "0",
                },
            ),
            spec=models.IoK8sApiCoreV1PodSpec(
                containers=[models.IoK8sApiCoreV1Container(**container_fields)]
            ),
            status=models.IoK8sApiCoreV1PodStatus(phase="Running"),
        )

    pods = [
        # Dataset initializer pod
        _running_pod(
            "dataset-initializer-pod",
            constants.DATASET_INITIALIZER,
            "dataset-initializer:latest",
            "dataset_initializer",
        ),
        # Model initializer pod
        _running_pod(
            "model-initializer-pod",
            constants.MODEL_INITIALIZER,
            "model-initializer:latest",
            "model_initializer",
        ),
        # Training node pod
        _running_pod(
            "node-0-pod",
            constants.NODE,
            "trainer:latest",
            "trainer",
            resources=get_resource_requirements(),
        ),
    ]
    return models.IoK8sApiCoreV1PodList(items=pods)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements:
    """Return ResourceRequirements with one GPU and 2Gi/4Gi memory request/limit."""
    quantity = models.IoK8sApimachineryPkgApiResourceQuantity
    gpu_resource = "nvidia.com/gpu"

    requested = {gpu_resource: quantity("1"), "memory": quantity("2Gi")}
    limited = {gpu_resource: quantity("1"), "memory": quantity("4Gi")}
    return models.IoK8sApiCoreV1ResourceRequirements(requests=requested, limits=limited)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def get_custom_trainer(
    env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None,
    pip_index_urls: Optional[list[str]] = constants.DEFAULT_PIP_INDEX_URLS,
    packages_to_install: Optional[list[str]] = None,
    image: Optional[str] = None,
) -> models.TrainerV1alpha1Trainer:
    """
    Get the custom trainer for the TrainJob.

    The returned trainer runs a bash script that makes sure pip is available,
    installs the requested packages (first with ``--user``, then without), writes
    the training script to ``backend_test.py`` and launches it with ``torchrun``.

    Args:
        env: Environment variables to set on the trainer.
        pip_index_urls: Index URLs for pip; the first is passed as ``--index-url``
            and the rest as ``--extra-index-url``. ``None`` or an empty list falls
            back to ``constants.DEFAULT_PIP_INDEX_URLS``.
        packages_to_install: Packages to install; defaults to
            ``["torch", "numpy"]`` when ``None``.
        image: Optional container image for the trainer.

    Returns:
        A trainer with ``numNodes=2`` and the generated bash command.
    """
    # The parameter is declared Optional, so do not index into None (or an
    # empty list); use the default index URLs instead.
    index_urls = pip_index_urls or constants.DEFAULT_PIP_INDEX_URLS
    # Resolve the default here rather than in the signature to avoid sharing a
    # mutable list between calls.
    packages = ["torch", "numpy"] if packages_to_install is None else packages_to_install

    pip_options = [f"--index-url {index_urls[0]}"]
    pip_options.extend(f"--extra-index-url {repo}" for repo in index_urls[1:])
    pip_command = " ".join(pip_options)

    packages_command = " ".join(packages)
    return models.TrainerV1alpha1Trainer(
        command=[
            "bash",
            "-c",
            '\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
            "|| python -m ensurepip --user || apt-get install python-pip"
            "\nfi\n\n"
            "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
            f" --no-warn-script-location {pip_command} --user {packages_command}"
            " ||\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
            f" --no-warn-script-location {pip_command} {packages_command}"
            "\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
            'print("Hello World"),\n\n<lambda>(**'
            "{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
            '"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
        ],
        numNodes=2,
        env=env,
        image=image,
    )
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def get_custom_trainer_container(
    image: str,
    num_nodes: int,
    resources_per_node: models.IoK8sApiCoreV1ResourceRequirements,
    env: list[models.IoK8sApiCoreV1EnvVar],
) -> models.TrainerV1alpha1Trainer:
    """
    Get the custom trainer container for the TrainJob.

    Every argument is forwarded to the matching trainer field
    (``image``, ``numNodes``, ``resourcesPerNode``, ``env``).
    """
    trainer_fields = {
        "image": image,
        "numNodes": num_nodes,
        "resourcesPerNode": resources_per_node,
        "env": env,
    }
    return models.TrainerV1alpha1Trainer(**trainer_fields)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def get_builtin_trainer(
    args: list[str],
) -> models.TrainerV1alpha1Trainer:
    """
    Get the builtin trainer for the TrainJob.

    The trainer runs ``tune run`` with the given ``args`` on two nodes.
    """
    tune_command = ["tune", "run"]
    return models.TrainerV1alpha1Trainer(command=tune_command, args=args, numNodes=2)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def get_train_job(
    runtime_name: str,
    train_job_name: str = BASIC_TRAIN_JOB_NAME,
    train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
    labels: Optional[dict[str, str]] = None,
    annotations: Optional[dict[str, str]] = None,
    spec_labels: Optional[dict[str, str]] = None,
    spec_annotations: Optional[dict[str, str]] = None,
) -> models.TrainerV1alpha1TrainJob:
    """
    Build a TrainJob that references ``runtime_name``.

    ``labels`` / ``annotations`` are placed on the object metadata, while
    ``spec_labels`` / ``spec_annotations`` are placed on the TrainJob spec.
    The trainer is optional and passed through unchanged.
    """
    job_metadata = models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
        name=train_job_name,
        labels=labels,
        annotations=annotations,
    )
    job_spec = models.TrainerV1alpha1TrainJobSpec(
        runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
        trainer=train_job_trainer,
        labels=spec_labels,
        annotations=spec_annotations,
    )
    return models.TrainerV1alpha1TrainJob(
        apiVersion=constants.API_VERSION,
        kind=constants.TRAINJOB_KIND,
        metadata=job_metadata,
        spec=job_spec,
    )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def get_cluster_custom_object_response(*args, **kwargs):
    """Return a mock whose ``get()`` yields a ClusterTrainingRuntime.

    ``args[3]`` is the requested object name and ``args[2]`` the resource plural.

    Raises:
        multiprocessing.TimeoutError: if the name equals ``TIMEOUT``.
        RuntimeError: if the name equals ``RUNTIME``.
    """
    requested_name = args[3]
    if requested_name == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if requested_name == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    # For any other plural, `get()` is left as a plain auto-created Mock attribute.
    if args[2] == constants.CLUSTER_TRAINING_RUNTIME_PLURAL:
        runtime = create_cluster_training_runtime(name=requested_name)
        async_result.get.return_value = normalize_model(
            runtime,
            models.TrainerV1alpha1ClusterTrainingRuntime,
        )
    return async_result
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def get_namespaced_custom_object_response(*args, **kwargs):
    """Return a mock whose ``get()`` yields a TrainJob with a status attached.

    Raises:
        multiprocessing.TimeoutError: if ``args[2]`` or ``args[4]`` equals ``TIMEOUT``.
        RuntimeError: if ``args[2]`` or ``args[4]`` equals ``RUNTIME``.
    """

    def _requested(marker):
        # Same short-circuit order as two chained comparisons: args[4] is only
        # inspected when args[2] does not match.
        return args[2] == marker or args[4] == marker

    if _requested(TIMEOUT):
        raise multiprocessing.TimeoutError()
    if _requested(RUNTIME):
        raise RuntimeError()

    async_result = Mock()
    if args[3] == TRAIN_JOBS:  # TODO: review this.
        train_job = create_train_job(train_job_name=args[4])
        async_result.get.return_value = add_status(train_job)
    return async_result
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def add_status(
    train_job: models.TrainerV1alpha1TrainJob,
) -> models.TrainerV1alpha1TrainJob:
    """
    Attach a status with a single "Complete" condition to the train job.

    The passed object is modified in place and also returned.
    """
    # One condition marking the job as successfully completed, stamped with the
    # current local time.
    completed = models.IoK8sApimachineryPkgApisMetaV1Condition(
        type="Complete",
        status="True",
        lastTransitionTime=datetime.datetime.now(),
        reason="JobCompleted",
        message="Job completed successfully",
    )
    train_job.status = models.TrainerV1alpha1TrainJobStatus(conditions=[completed])
    return train_job
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def list_namespaced_custom_object_response(*args, **kwargs):
    """Return a mock whose ``get()`` yields a TrainJobList with two jobs.

    Raises:
        multiprocessing.TimeoutError: if ``args[2]`` equals ``TIMEOUT``.
        RuntimeError: if ``args[2]`` equals ``RUNTIME``.
    """
    selector = args[2]
    if selector == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if selector == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    if args[3] == constants.TRAINJOB_PLURAL:
        jobs = [
            add_status(create_train_job(train_job_name=job_name))
            for job_name in ("basic-job-1", "basic-job-2")
        ]
        async_result.get.return_value = normalize_model(
            models.TrainerV1alpha1TrainJobList(items=jobs),
            models.TrainerV1alpha1TrainJobList,
        )
    return async_result
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def list_cluster_custom_object(*args, **kwargs):
    """Return a mock whose ``get()`` yields a list of two ClusterTrainingRuntimes.

    Raises:
        multiprocessing.TimeoutError: if ``args[2]`` equals ``TIMEOUT``.
        RuntimeError: if ``args[2]`` equals ``RUNTIME``.
    """
    plural = args[2]
    if plural == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if plural == RUNTIME:
        raise RuntimeError()

    async_result = Mock()
    if plural == constants.CLUSTER_TRAINING_RUNTIME_PLURAL:
        runtimes = [
            create_cluster_training_runtime(name=runtime_name)
            for runtime_name in ("runtime-1", "runtime-2")
        ]
        async_result.get.return_value = normalize_model(
            models.TrainerV1alpha1ClusterTrainingRuntimeList(items=runtimes),
            models.TrainerV1alpha1ClusterTrainingRuntimeList,
        )
    return async_result
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def mock_read_namespaced_pod_log(*args, **kwargs):
    """Return fixed log text, or raise when the ``namespace`` keyword is ``FAIL_LOGS``."""
    should_fail = kwargs.get("namespace") == FAIL_LOGS
    if should_fail:
        raise Exception("Failed to read logs")
    return "test log content"
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def mock_list_namespaced_event(*args, **kwargs):
    """Return a mock whose ``get()`` yields an EventList with two events.

    The first event refers to the basic TrainJob, the second to the
    ``node-0-pod`` pod.

    Raises:
        multiprocessing.TimeoutError: if the ``namespace`` keyword equals ``TIMEOUT``.
        RuntimeError: if the ``namespace`` keyword equals ``RUNTIME``.
    """
    requested_namespace = kwargs.get("namespace")

    # Errors occur at call time, not during .get()
    if requested_namespace == TIMEOUT:
        raise multiprocessing.TimeoutError()
    if requested_namespace == RUNTIME:
        raise RuntimeError()

    def _event(event_name, kind, object_name, message, reason, first_seen):
        return models.IoK8sApiCoreV1Event(
            metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
                name=event_name,
                namespace=DEFAULT_NAMESPACE,
            ),
            involvedObject=models.IoK8sApiCoreV1ObjectReference(
                kind=kind,
                name=object_name,
                namespace=DEFAULT_NAMESPACE,
            ),
            message=message,
            reason=reason,
            firstTimestamp=first_seen,
        )

    events = [
        _event(
            "test-event-1",
            constants.TRAINJOB_KIND,
            BASIC_TRAIN_JOB_NAME,
            "TrainJob created successfully",
            "Created",
            datetime.datetime(2025, 6, 1, 10, 30, 0),
        ),
        _event(
            "test-event-2",
            "Pod",
            "node-0-pod",
            "Pod scheduled successfully",
            "Scheduled",
            datetime.datetime(2025, 6, 1, 10, 31, 0),
        ),
    ]

    async_result = Mock()
    async_result.get.return_value = models.IoK8sApiCoreV1EventList(items=events)
    return async_result
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def mock_watch(*args, **kwargs):
    """Return an iterator over a single MODIFIED event for a running node pod.

    Raises:
        TimeoutError: if the ``timeout_seconds`` keyword equals 1.
    """
    if kwargs.get("timeout_seconds") == 1:
        raise TimeoutError("Watch timeout")

    pod_labels = {
        constants.JOBSET_NAME_LABEL: BASIC_TRAIN_JOB_NAME,
        constants.JOBSET_RJOB_NAME_LABEL: constants.NODE,
        constants.JOB_INDEX_LABEL: "0",
    }
    running_pod = {
        "metadata": {
            "name": f"{BASIC_TRAIN_JOB_NAME}-node-0",
            "labels": pod_labels,
        },
        "spec": {"containers": [{"name": constants.NODE}]},
        "status": {"phase": "Running"},
    }
    return iter([{"type": "MODIFIED", "object": running_pod}])
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def normalize_model(model_obj, model_class):
    """Round-trip a model through its dict form, as a real API response would be.

    The object is dumped with ``to_dict()`` and re-parsed with
    ``model_class.from_dict()``, so the result is a freshly instantiated
    ``model_class`` object.
    """
    raw_payload = model_obj.to_dict()
    return model_class.from_dict(raw_payload)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
# --------------------------
|
|
493
|
+
# Object Creators
|
|
494
|
+
# --------------------------
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def create_train_job(
    train_job_name: Optional[str] = None,
    namespace: str = "default",
    image: str = "pytorch/pytorch:latest",
    initializer: Optional[types.Initializer] = None,
    command: Optional[list] = None,
    args: Optional[list] = None,
) -> models.TrainerV1alpha1TrainJob:
    """Create a mock TrainJob object.

    Args:
        train_job_name: Name of the TrainJob. When omitted, a fresh random name
            (one lowercase letter followed by 11 hex characters) is generated
            for every call.
        namespace: Namespace stored in the object metadata.
        image: Currently unused by this helper.
        initializer: Optional initializer; when truthy, its dataset and model
            parts are converted with the Kubernetes utils helpers.
        command: Currently unused by this helper.
        args: Currently unused by this helper.

    Returns:
        A TrainJob referencing the torch runtime, with no trainer set and a
        fixed creation timestamp.
    """
    # Generate the name per call: a random expression used as the parameter
    # default would be evaluated only once, when the function is defined.
    if train_job_name is None:
        train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]

    return models.TrainerV1alpha1TrainJob(
        apiVersion=constants.API_VERSION,
        kind=constants.TRAINJOB_KIND,
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            name=train_job_name,
            namespace=namespace,
            creationTimestamp=datetime.datetime(2025, 6, 1, 10, 30, 0),
        ),
        spec=models.TrainerV1alpha1TrainJobSpec(
            runtimeRef=models.TrainerV1alpha1RuntimeRef(name=TORCH_RUNTIME),
            trainer=None,
            initializer=(
                models.TrainerV1alpha1Initializer(
                    dataset=utils.get_dataset_initializer(initializer.dataset),
                    model=utils.get_model_initializer(initializer.model),
                )
                if initializer
                else None
            ),
        ),
    )
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def create_cluster_training_runtime(
    name: str,
    namespace: str = "default",
) -> models.TrainerV1alpha1ClusterTrainingRuntime:
    """Build a ClusterTrainingRuntime named ``name``.

    The runtime's framework label is set to ``name``; its ML policy uses
    torch with 2 processes per node on 2 nodes, and its JobSet template holds
    the single replicated job from ``get_replicated_job()``.
    """
    ml_policy = models.TrainerV1alpha1MLPolicy(
        torch=models.TrainerV1alpha1TorchMLPolicySource(
            numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2)
        ),
        numNodes=2,
    )
    job_set_template = models.TrainerV1alpha1JobSetTemplateSpec(
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            name=name,
            namespace=namespace,
        ),
        spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]),
    )
    runtime_metadata = models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
        name=name,
        namespace=namespace,
        labels={constants.RUNTIME_FRAMEWORK_LABEL: name},
    )

    return models.TrainerV1alpha1ClusterTrainingRuntime(
        apiVersion=constants.API_VERSION,
        kind="ClusterTrainingRuntime",
        metadata=runtime_metadata,
        spec=models.TrainerV1alpha1TrainingRuntimeSpec(
            mlPolicy=ml_policy,
            template=job_set_template,
        ),
    )
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob:
    """Build the single-replica "node" replicated job used by the mock runtime."""
    pod_template = models.IoK8sApiCoreV1PodTemplateSpec(
        spec=models.IoK8sApiCoreV1PodSpec(containers=[get_container()])
    )
    job_template = models.IoK8sApiBatchV1JobTemplateSpec(
        metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
            labels={"trainer.kubeflow.org/trainjob-ancestor-step": "trainer"}
        ),
        spec=models.IoK8sApiBatchV1JobSpec(template=pod_template),
    )
    return models.JobsetV1alpha2ReplicatedJob(name="node", replicas=1, template=job_template)
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def get_container() -> models.IoK8sApiCoreV1Container:
    """Build the "node" container of the mock runtime, including its resources."""
    container_fields = {
        "name": "node",
        "image": "example.com/test-runtime",
        "command": ["echo", "Hello World"],
        "resources": get_resource_requirements(),
    }
    return models.IoK8sApiCoreV1Container(**container_fields)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def create_runtime_type(
    name: str,
) -> types.Runtime:
    """Build the Runtime expected for a mock runtime called ``name``.

    The trainer is a custom trainer whose framework equals ``name``, running on
    2 nodes with ``RUNTIME_DEVICES`` GPU devices and the torch command.
    """
    runtime_trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=name,
        num_nodes=2,
        device="gpu",
        device_count=RUNTIME_DEVICES,
        image="example.com/test-runtime",
    )
    runtime_trainer.set_command(constants.TORCH_COMMAND)

    return types.Runtime(name=name, trainer=runtime_trainer)
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def get_train_job_data_type(
    runtime_name: str,
    train_job_name: str,
) -> types.TrainJob:
    """Build the TrainJob expected for the mocked pods and completed status.

    The job has three running steps (dataset initializer, model initializer
    and node-0), 2 nodes, a fixed creation timestamp and status "Complete".
    """
    runtime_trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=runtime_name,
        device="gpu",
        device_count=RUNTIME_DEVICES,
        num_nodes=2,
        image="example.com/test-runtime",
    )
    runtime_trainer.set_command(constants.TORCH_COMMAND)

    # (step name, pod name, device, device count) for each expected step.
    step_rows = [
        ("dataset-initializer", "dataset-initializer-pod", "Unknown", "Unknown"),
        ("model-initializer", "model-initializer-pod", "Unknown", "Unknown"),
        ("node-0", "node-0-pod", "gpu", "1"),
    ]
    expected_steps = [
        types.Step(
            name=step_name,
            status="Running",
            pod_name=pod_name,
            device=device,
            device_count=device_count,
        )
        for step_name, pod_name, device, device_count in step_rows
    ]

    return types.TrainJob(
        name=train_job_name,
        creation_timestamp=datetime.datetime(2025, 6, 1, 10, 30, 0),
        runtime=types.Runtime(name=runtime_name, trainer=runtime_trainer),
        steps=expected_steps,
        num_nodes=2,
        status="Complete",
    )
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
# --------------------------
|
|
657
|
+
# Tests
|
|
658
|
+
# --------------------------
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": TORCH_RUNTIME},
            expected_output=create_runtime_type(name=TORCH_RUNTIME),
        ),
        TestCase(
            name="timeout error when getting runtime",
            expected_status=FAILED,
            config={"name": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting runtime",
            expected_status=FAILED,
            config={"name": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_runtime(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_runtime for the success path and error paths.

    Success cases compare the returned Runtime with the expected output.
    Failure cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)

    if test_case.expected_status == SUCCESS:
        # No try/except here: an unexpected exception or a failed assertion
        # must surface with its own traceback rather than being caught and
        # re-reported as an exception-type mismatch.
        runtime = kubernetes_backend.get_runtime(**test_case.config)

        assert isinstance(runtime, types.Runtime)
        assert asdict(runtime) == asdict(test_case.expected_output)
    else:
        with pytest.raises(test_case.expected_error) as exc_info:
            kubernetes_backend.get_runtime(**test_case.config)
        # Keep the strict type check: subclasses of the expected error do not count.
        assert exc_info.type is test_case.expected_error

    print("test execution complete")
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": LIST_RUNTIMES},
            expected_output=[
                create_runtime_type(name="runtime-1"),
                create_runtime_type(name="runtime-2"),
            ],
        ),
    ],
)
def test_list_runtimes(kubernetes_backend, test_case):
    """Test KubernetesBackend.list_runtimes.

    Success cases compare the returned runtimes with the expected output.
    Failure cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)

    if test_case.expected_status == SUCCESS:
        # No try/except here: an unexpected exception or a failed assertion
        # must surface with its own traceback rather than being caught and
        # re-reported as an exception-type mismatch.
        runtimes = kubernetes_backend.list_runtimes()

        assert isinstance(runtimes, list)
        assert all(isinstance(r, types.Runtime) for r in runtimes)
        assert [asdict(r) for r in runtimes] == [asdict(r) for r in test_case.expected_output]
    else:
        with pytest.raises(test_case.expected_error) as exc_info:
            kubernetes_backend.list_runtimes()
        # Keep the strict type check: subclasses of the expected error do not count.
        assert exc_info.type is test_case.expected_error

    print("test execution complete")
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with custom trainer runtime",
            expected_status=SUCCESS,
            config={"runtime": create_runtime_type(name=TORCH_RUNTIME)},
        ),
        TestCase(
            name="value error with builtin trainer runtime",
            expected_status=FAILED,
            config={
                "runtime": types.Runtime(
                    name="torchtune-runtime",
                    trainer=types.RuntimeTrainer(
                        trainer_type=types.TrainerType.BUILTIN_TRAINER,
                        framework="torchtune",
                        num_nodes=1,
                        device="cpu",
                        device_count="1",
                        image="example.com/image",
                    ),
                )
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_runtime_packages(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_runtime_packages for success and error paths.

    Each case passes ``test_case.config`` as keyword arguments to
    ``get_runtime_packages``. A raised exception must be exactly of type
    ``test_case.expected_error``; when nothing is raised the case must be a
    SUCCESS case.
    """
    print("Executing test:", test_case.name)

    try:
        kubernetes_backend.get_runtime_packages(**test_case.config)
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Without this check a FAILED case that raises nothing would pass
        # silently, because no other assertion runs on the no-exception path.
        assert test_case.expected_status == SUCCESS

    print("test execution complete")
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={},
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="valid flow with built in trainer",
            expected_status=SUCCESS,
            config={
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        num_nodes=2,
                        batch_size=2,
                        epochs=2,
                        loss=types.Loss.CEWithChunkedOutputLoss,
                    )
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            expected_output=get_train_job(
                runtime_name=TORCH_TUNE_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
                train_job_trainer=get_builtin_trainer(
                    args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"],
                ),
            ),
        ),
        TestCase(
            name="valid flow with built in trainer and lora config",
            expected_status=SUCCESS,
            config={
                "trainer": types.BuiltinTrainer(
                    config=types.TorchTuneConfig(
                        num_nodes=2,
                        peft_config=types.LoraConfig(
                            apply_lora_to_mlp=True,
                            lora_rank=8,
                            lora_alpha=16,
                            lora_dropout=0.1,
                        ),
                    ),
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            expected_output=get_train_job(
                runtime_name=TORCH_TUNE_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
                train_job_trainer=get_builtin_trainer(
                    args=[
                        "model.apply_lora_to_mlp=True",
                        "model.lora_rank=8",
                        "model.lora_alpha=16",
                        "model.lora_dropout=0.1",
                        "model.lora_attn_modules=[q_proj,v_proj,output_proj]",
                    ],
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    func_args={"learning_rate": 0.001, "batch_size": 32},
                    packages_to_install=["torch", "numpy"],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    num_nodes=2,
                )
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer(
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    packages_to_install=["torch", "numpy"],
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer that has env and image",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    func_args={"learning_rate": 0.001, "batch_size": 32},
                    packages_to_install=["torch", "numpy"],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    num_nodes=2,
                    env={
                        "TEST_ENV": "test_value",
                        "ANOTHER_ENV": "another_value",
                    },
                    image="my-custom-image",
                )
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer(
                    env=[
                        models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
                        models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
                    ],
                    pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
                    packages_to_install=["torch", "numpy"],
                    image="my-custom-image",
                ),
            ),
        ),
        TestCase(
            name="valid flow with custom trainer container",
            expected_status=SUCCESS,
            config={
                "trainer": types.CustomTrainerContainer(
                    image="example.com/my-image",
                    num_nodes=2,
                    resources_per_node={"cpu": 5, "gpu": 3},
                    env={
                        "TEST_ENV": "test_value",
                        "ANOTHER_ENV": "another_value",
                    },
                )
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
                train_job_trainer=get_custom_trainer_container(
                    image="example.com/my-image",
                    num_nodes=2,
                    resources_per_node=models.IoK8sApiCoreV1ResourceRequirements(
                        requests={
                            "cpu": models.IoK8sApimachineryPkgApiResourceQuantity(5),
                            "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(3),
                        },
                        limits={
                            "cpu": models.IoK8sApimachineryPkgApiResourceQuantity(5),
                            "nvidia.com/gpu": models.IoK8sApimachineryPkgApiResourceQuantity(3),
                        },
                    ),
                    env=[
                        models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
                        models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
                    ],
                ),
            ),
        ),
        TestCase(
            name="timeout error when deleting job",
            expected_status=FAILED,
            config={
                "namespace": TIMEOUT,
            },
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when deleting job",
            expected_status=FAILED,
            config={
                "namespace": RUNTIME,
            },
            expected_error=RuntimeError,
        ),
        TestCase(
            name="value error when runtime doesn't support CustomTrainer",
            expected_status=FAILED,
            config={
                "trainer": types.CustomTrainer(
                    func=lambda: print("Hello World"),
                    num_nodes=2,
                ),
                "runtime": TORCH_TUNE_RUNTIME,
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="train with metadata labels and annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    Labels({"team": "ml-platform"}),
                    Annotations({"created-by": "sdk"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                labels={"team": "ml-platform"},
                annotations={"created-by": "sdk"},
            ),
        ),
        TestCase(
            name="train with spec labels and annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    SpecLabels({"app": "training", "version": "v1.0"}),
                    SpecAnnotations({"prometheus.io/scrape": "true"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                spec_labels={"app": "training", "version": "v1.0"},
                spec_annotations={"prometheus.io/scrape": "true"},
            ),
        ),
        TestCase(
            name="train with both metadata and spec labels/annotations",
            expected_status=SUCCESS,
            config={
                "options": [
                    Labels({"owner": "ml-team"}),
                    Annotations({"description": "Fine-tuning job"}),
                    SpecLabels({"app": "training", "version": "v1.0"}),
                    SpecAnnotations({"prometheus.io/scrape": "true"}),
                ],
            },
            expected_output=get_train_job(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
                labels={"owner": "ml-team"},
                annotations={"description": "Fine-tuning job"},
                spec_labels={"app": "training", "version": "v1.0"},
                spec_annotations={"prometheus.io/scrape": "true"},
            ),
        ),
    ],
)
def test_train(kubernetes_backend, test_case):
    """Test KubernetesBackend.train for the success and error paths.

    Each case sets the backend namespace, resolves a runtime through
    ``get_runtime`` (``TORCH_RUNTIME`` unless the case overrides it) and calls
    ``train`` with the case's trainer and options. A raised exception must be
    exactly of type ``test_case.expected_error``; otherwise the case must be a
    SUCCESS case and ``create_namespaced_custom_object`` must have been called
    with the expected TrainJob dict.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    options = test_case.config.get("options", [])

    try:
        runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME))
        train_job_name = kubernetes_backend.train(
            runtime=runtime,
            trainer=test_case.config.get("trainer", None),
            options=options,
        )
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a mismatch in the mock call
        # is reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS

        # This is to get around the fact that the train job name is dynamically generated
        # In the future name generation may be more deterministic, and we can revisit this approach
        expected_output = test_case.expected_output
        expected_output.metadata.name = train_job_name

        kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with(
            constants.GROUP,
            constants.VERSION,
            DEFAULT_NAMESPACE,
            constants.TRAINJOB_PLURAL,
            expected_output.to_dict(),
        )
    print("test execution complete")
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="timeout error when getting job",
            expected_status=FAILED,
            config={"name": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting job",
            expected_status=FAILED,
            config={"name": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_job for the success and error paths.

    Each case passes ``test_case.config`` as keyword arguments to ``get_job``.
    A raised exception must be exactly of type ``test_case.expected_error``;
    otherwise the case must be a SUCCESS case and the returned job must equal
    ``test_case.expected_output``.
    """
    print("Executing test:", test_case.name)
    try:
        job = kubernetes_backend.get_job(**test_case.config)
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a failing comparison is
        # reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS
        assert asdict(job) == asdict(test_case.expected_output)
    print("test execution complete")
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={},
            expected_output=[
                get_train_job_data_type(
                    runtime_name=TORCH_RUNTIME,
                    train_job_name="basic-job-1",
                ),
                get_train_job_data_type(
                    runtime_name=TORCH_RUNTIME,
                    train_job_name="basic-job-2",
                ),
            ],
        ),
        TestCase(
            name="timeout error when listing jobs",
            expected_status=FAILED,
            config={"namespace": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when listing jobs",
            expected_status=FAILED,
            config={"namespace": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_list_jobs(kubernetes_backend, test_case):
    """Test KubernetesBackend.list_jobs for the success and error paths.

    The backend namespace is taken from ``test_case.config`` (falling back to
    ``DEFAULT_NAMESPACE``) before ``list_jobs`` is called. A raised exception
    must be exactly of type ``test_case.expected_error``; otherwise the case
    must be a SUCCESS case and the returned list must equal
    ``test_case.expected_output``.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    try:
        jobs = kubernetes_backend.list_jobs()
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a failing comparison is
        # reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS
        assert isinstance(jobs, list)
        # Compare against the case's own expectation rather than a literal count.
        assert len(jobs) == len(test_case.expected_output)
        assert [asdict(j) for j in jobs] == [asdict(r) for r in test_case.expected_output]
    print("test execution complete")
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=["test log content"],
        ),
        TestCase(
            name="runtime error when getting logs",
            expected_status=FAILED,
            config={"name": BASIC_TRAIN_JOB_NAME, "namespace": FAIL_LOGS},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job_logs(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_job_logs for the success and error paths.

    The backend namespace is taken from ``test_case.config`` (falling back to
    ``DEFAULT_NAMESPACE``) before ``get_job_logs`` is called with the case's
    job name. A raised exception must be exactly of type
    ``test_case.expected_error``; otherwise the case must be a SUCCESS case
    and the collected log lines must equal ``test_case.expected_output``.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    try:
        logs = kubernetes_backend.get_job_logs(test_case.config.get("name"))
        # Convert iterator to list for comparison. This stays inside the
        # ``try`` because consuming the iterator may be what raises.
        logs_list = list(logs)
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a failing comparison is
        # reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS
        assert logs_list == test_case.expected_output
    print("test execution complete")
|
|
1160
|
+
|
|
1161
|
+
|
|
1162
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="wait for complete status (default)",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="wait for multiple statuses",
            expected_status=SUCCESS,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {constants.TRAINJOB_RUNNING, constants.TRAINJOB_COMPLETE},
            },
            expected_output=get_train_job_data_type(
                runtime_name=TORCH_RUNTIME,
                train_job_name=BASIC_TRAIN_JOB_NAME,
            ),
        ),
        TestCase(
            name="invalid status set error",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {"InvalidStatus"},
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="polling interval is more than timeout error",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "timeout": 1,
                "polling_interval": 2,
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="job failed when not expected",
            expected_status=FAILED,
            config={
                "name": "failed-job",
                "status": {constants.TRAINJOB_RUNNING},
            },
            expected_error=RuntimeError,
        ),
        TestCase(
            name="timeout error to wait for failed status",
            expected_status=FAILED,
            config={
                "name": BASIC_TRAIN_JOB_NAME,
                "status": {constants.TRAINJOB_FAILED},
                "polling_interval": 1,
                "timeout": 2,
            },
            expected_error=TimeoutError,
        ),
    ],
)
def test_wait_for_job_status(kubernetes_backend, test_case):
    """Test KubernetesBackend.wait_for_job_status with various scenarios.

    ``get_job`` on the backend is wrapped so that the ``"failed-job"`` case
    observes a job whose status is ``constants.TRAINJOB_FAILED``. A raised
    exception must be exactly of type ``test_case.expected_error``; otherwise
    the case must be a SUCCESS case and the returned job's status must be in
    the requested status set (``{constants.TRAINJOB_COMPLETE}`` when the case
    does not specify one).
    """
    print("Executing test:", test_case.name)

    original_get_job = kubernetes_backend.get_job

    # TrainJob has unexpected failed status.
    def mock_get_job(name):
        job = original_get_job(name)
        if test_case.config.get("name") == "failed-job":
            job.status = constants.TRAINJOB_FAILED
        return job

    kubernetes_backend.get_job = mock_get_job

    try:
        job = kubernetes_backend.wait_for_job_status(**test_case.config)
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a failing comparison is
        # reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS
        assert isinstance(job, types.TrainJob)
        # Job status should be in the expected set.
        assert job.status in test_case.config.get("status", {constants.TRAINJOB_COMPLETE})
    finally:
        # Undo the patch so the wrapper cannot leak if the backend object is
        # reused (the fixture's scope is not visible here).
        kubernetes_backend.get_job = original_get_job

    print("test execution complete")
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid flow with all defaults",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=None,
        ),
        TestCase(
            name="timeout error when deleting job",
            expected_status=FAILED,
            config={"namespace": TIMEOUT},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when deleting job",
            expected_status=FAILED,
            config={"namespace": RUNTIME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_delete_job(kubernetes_backend, test_case):
    """Test KubernetesBackend.delete_job for the success and error paths.

    The backend namespace is taken from ``test_case.config`` (falling back to
    ``DEFAULT_NAMESPACE``) before ``delete_job`` is called with the case's job
    name. A raised exception must be exactly of type
    ``test_case.expected_error``; when nothing is raised the case must be a
    SUCCESS case.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    try:
        kubernetes_backend.delete_job(test_case.config.get("name"))
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Checked outside the ``try`` so a FAILED case that raises nothing is
        # reported as this assertion rather than as an exception-type mismatch.
        assert test_case.expected_status == SUCCESS
    print("test execution complete")
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="get job events with valid trainjob",
            expected_status=SUCCESS,
            config={"name": BASIC_TRAIN_JOB_NAME},
            expected_output=[
                types.Event(
                    involved_object_kind=constants.TRAINJOB_KIND,
                    involved_object_name=BASIC_TRAIN_JOB_NAME,
                    message="TrainJob created successfully",
                    reason="Created",
                    event_time=datetime.datetime(2025, 6, 1, 10, 30, 0),
                ),
                types.Event(
                    involved_object_kind="Pod",
                    involved_object_name="node-0-pod",
                    message="Pod scheduled successfully",
                    reason="Scheduled",
                    event_time=datetime.datetime(2025, 6, 1, 10, 31, 0),
                ),
            ],
        ),
        TestCase(
            name="timeout error when getting job events",
            expected_status=FAILED,
            config={"namespace": TIMEOUT, "name": BASIC_TRAIN_JOB_NAME},
            expected_error=TimeoutError,
        ),
        TestCase(
            name="runtime error when getting job events",
            expected_status=FAILED,
            config={"namespace": RUNTIME, "name": BASIC_TRAIN_JOB_NAME},
            expected_error=RuntimeError,
        ),
    ],
)
def test_get_job_events(kubernetes_backend, test_case):
    """Test KubernetesBackend.get_job_events with various scenarios.

    The backend namespace is taken from ``test_case.config`` (falling back to
    ``DEFAULT_NAMESPACE``) before ``get_job_events`` is called with the case's
    job name. A raised exception must be exactly of type
    ``test_case.expected_error``; otherwise the case must be a SUCCESS case
    and the returned events must equal ``test_case.expected_output``.
    """
    print("Executing test:", test_case.name)
    kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
    try:
        events = kubernetes_backend.get_job_events(test_case.config.get("name"))
    except Exception as e:
        assert type(e) is test_case.expected_error
    else:
        # Assertions live outside the ``try`` so a failing comparison is
        # reported by pytest directly instead of being caught above.
        assert test_case.expected_status == SUCCESS
        assert isinstance(events, list)
        assert len(events) == len(test_case.expected_output)
        assert [asdict(event) for event in events] == [
            asdict(event) for event in test_case.expected_output
        ]
    print("test execution complete")
|