viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,582 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from kubeflow_trainer_api import models
|
|
16
|
+
import pytest
|
|
17
|
+
|
|
18
|
+
import viettelcloud.aiplatform.trainer.backends.kubernetes.utils as utils
|
|
19
|
+
from viettelcloud.aiplatform.trainer.constants import constants
|
|
20
|
+
from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
|
|
21
|
+
from viettelcloud.aiplatform.trainer.types import types
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _build_runtime() -> types.Runtime:
    """Build the minimal ``Runtime`` shared by the train-func command tests.

    The runtime wraps a CUSTOM_TRAINER on torch/cpu with a single device,
    a placeholder image, and the package's default command.
    """
    trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework="torch",
        device="cpu",
        device_count="1",
        image="example.com/image",
    )
    trainer.set_command(constants.DEFAULT_COMMAND)
    runtime = types.Runtime(name="test-runtime", trainer=trainer)
    return runtime
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="single MIG limit returns device and count",
            expected_status=SUCCESS,
            config={
                "resources": models.IoK8sApiCoreV1ResourceRequirements(
                    limits={
                        "nvidia.com/mig-1g.5gb": models.IoK8sApimachineryPkgApiResourceQuantity(2),
                    }
                )
            },
            expected_output=("mig-1g.5gb", "2.0"),
        ),
        TestCase(
            name="multiple MIG limits are not supported",
            expected_status=FAILED,
            config={
                "resources": models.IoK8sApiCoreV1ResourceRequirements(
                    limits={
                        "nvidia.com/mig-1g.5gb": models.IoK8sApimachineryPkgApiResourceQuantity(1),
                        "nvidia.com/mig-2g.10gb": models.IoK8sApimachineryPkgApiResourceQuantity(1),
                    }
                )
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_container_devices(test_case: TestCase):
    """Check ``utils.get_container_devices`` against each parametrized case.

    SUCCESS cases compare the returned value with ``expected_output``.
    FAILED cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)
    resources = test_case.config["resources"]

    # Branch on the expected outcome rather than wrapping everything in
    # ``try/except Exception``: that pattern also caught this test's own
    # AssertionError, hiding the actual-vs-expected diff behind a misleading
    # status/type mismatch.
    if test_case.expected_status == FAILED:
        with pytest.raises(test_case.expected_error) as exc_info:
            utils.get_container_devices(resources)
        # pytest.raises also matches subclasses; keep the strict exact-type check.
        assert exc_info.type is test_case.expected_error
    else:
        assert test_case.expected_status == SUCCESS
        device = utils.get_container_devices(resources)
        assert device == test_case.expected_output
    print("test execution complete")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="mig alias expands to fully qualified key",
            expected_status=SUCCESS,
            config={
                "resources_per_node": {
                    "MiG-1G.5GB": 2,
                    "cpu": "500m",
                }
            },
            expected_output=models.IoK8sApiCoreV1ResourceRequirements(
                limits={
                    "cpu": models.IoK8sApimachineryPkgApiResourceQuantity("500m"),
                    "nvidia.com/mig-1g.5gb": models.IoK8sApimachineryPkgApiResourceQuantity(2),
                },
                requests={
                    "cpu": models.IoK8sApimachineryPkgApiResourceQuantity("500m"),
                    "nvidia.com/mig-1g.5gb": models.IoK8sApimachineryPkgApiResourceQuantity(2),
                },
            ),
        ),
        TestCase(
            name="gpu and mig together raises error",
            expected_status=FAILED,
            config={"resources_per_node": {"gpu": 1, "mig-1g.5gb": 1}},
            expected_error=ValueError,
        ),
        TestCase(
            name="multiple mig resource types raises error",
            expected_status=FAILED,
            config={
                "resources_per_node": {
                    "mig-1g.5gb": 1,
                    "nvidia.com/mig-2g.10gb": 1,
                }
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_resources_per_node(test_case: TestCase):
    """Check ``utils.get_resources_per_node`` against each parametrized case.

    SUCCESS cases compare the returned resource requirements with
    ``expected_output``. FAILED cases require the call to raise exactly
    ``expected_error``.
    """
    print("Executing test:", test_case.name)
    resources_per_node = test_case.config["resources_per_node"]

    # Explicit branching instead of ``try/except Exception`` so a failing
    # equality assertion is reported as-is rather than swallowed by the
    # except branch.
    if test_case.expected_status == FAILED:
        with pytest.raises(test_case.expected_error) as exc_info:
            utils.get_resources_per_node(resources_per_node)
        # pytest.raises also matches subclasses; keep the strict exact-type check.
        assert exc_info.type is test_case.expected_error
    else:
        assert test_case.expected_status == SUCCESS
        resources = utils.get_resources_per_node(resources_per_node)
        assert resources == test_case.expected_output
    print("test execution complete")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@pytest.mark.parametrize(
|
|
137
|
+
"test_case",
|
|
138
|
+
[
|
|
139
|
+
TestCase(
|
|
140
|
+
name="multiple pip index URLs",
|
|
141
|
+
config={
|
|
142
|
+
"packages_to_install": ["torch", "numpy", "custom-package"],
|
|
143
|
+
"pip_index_urls": [
|
|
144
|
+
"https://pypi.org/simple",
|
|
145
|
+
"https://private.repo.com/simple",
|
|
146
|
+
"https://internal.company.com/simple",
|
|
147
|
+
],
|
|
148
|
+
"is_mpi": False,
|
|
149
|
+
},
|
|
150
|
+
expected_output=(
|
|
151
|
+
'\nif ! [ -x "$(command -v pip)" ]; then\n'
|
|
152
|
+
" python -m ensurepip || python -m ensurepip --user || "
|
|
153
|
+
"apt-get install python-pip\n"
|
|
154
|
+
"fi\n\n"
|
|
155
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
156
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
157
|
+
"--extra-index-url https://private.repo.com/simple "
|
|
158
|
+
"--extra-index-url https://internal.company.com/simple "
|
|
159
|
+
"--user torch numpy custom-package ||\n"
|
|
160
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
161
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
162
|
+
"--extra-index-url https://private.repo.com/simple "
|
|
163
|
+
"--extra-index-url https://internal.company.com/simple "
|
|
164
|
+
"torch numpy custom-package\n"
|
|
165
|
+
),
|
|
166
|
+
),
|
|
167
|
+
TestCase(
|
|
168
|
+
name="single pip index URL (backward compatibility)",
|
|
169
|
+
config={
|
|
170
|
+
"packages_to_install": ["torch", "numpy", "custom-package"],
|
|
171
|
+
"pip_index_urls": ["https://pypi.org/simple"],
|
|
172
|
+
"is_mpi": False,
|
|
173
|
+
},
|
|
174
|
+
expected_output=(
|
|
175
|
+
'\nif ! [ -x "$(command -v pip)" ]; then\n'
|
|
176
|
+
" python -m ensurepip || python -m ensurepip --user || "
|
|
177
|
+
"apt-get install python-pip\n"
|
|
178
|
+
"fi\n\n"
|
|
179
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
180
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
181
|
+
"--user torch numpy custom-package ||\n"
|
|
182
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
183
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
184
|
+
"torch numpy custom-package\n"
|
|
185
|
+
),
|
|
186
|
+
),
|
|
187
|
+
TestCase(
|
|
188
|
+
name="multiple pip index URLs with MPI",
|
|
189
|
+
config={
|
|
190
|
+
"packages_to_install": ["torch", "numpy", "custom-package"],
|
|
191
|
+
"pip_index_urls": [
|
|
192
|
+
"https://pypi.org/simple",
|
|
193
|
+
"https://private.repo.com/simple",
|
|
194
|
+
"https://internal.company.com/simple",
|
|
195
|
+
],
|
|
196
|
+
"is_mpi": True,
|
|
197
|
+
},
|
|
198
|
+
expected_output=(
|
|
199
|
+
'\nif ! [ -x "$(command -v pip)" ]; then\n'
|
|
200
|
+
" python -m ensurepip || python -m ensurepip --user || "
|
|
201
|
+
"apt-get install python-pip\n"
|
|
202
|
+
"fi\n\n"
|
|
203
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
204
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
205
|
+
"--extra-index-url https://private.repo.com/simple "
|
|
206
|
+
"--extra-index-url https://internal.company.com/simple "
|
|
207
|
+
"--user torch numpy custom-package ||\n"
|
|
208
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
209
|
+
"--no-warn-script-location --index-url https://pypi.org/simple "
|
|
210
|
+
"--extra-index-url https://private.repo.com/simple "
|
|
211
|
+
"--extra-index-url https://internal.company.com/simple "
|
|
212
|
+
"torch numpy custom-package\n"
|
|
213
|
+
),
|
|
214
|
+
),
|
|
215
|
+
TestCase(
|
|
216
|
+
name="default pip index URLs",
|
|
217
|
+
config={
|
|
218
|
+
"packages_to_install": ["torch", "numpy"],
|
|
219
|
+
"pip_index_urls": constants.DEFAULT_PIP_INDEX_URLS,
|
|
220
|
+
"is_mpi": False,
|
|
221
|
+
},
|
|
222
|
+
expected_output=(
|
|
223
|
+
'\nif ! [ -x "$(command -v pip)" ]; then\n'
|
|
224
|
+
" python -m ensurepip || python -m ensurepip --user || "
|
|
225
|
+
"apt-get install python-pip\n"
|
|
226
|
+
"fi\n\n"
|
|
227
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
228
|
+
f"--no-warn-script-location --index-url "
|
|
229
|
+
f"{constants.DEFAULT_PIP_INDEX_URLS[0]} --user torch numpy ||\n"
|
|
230
|
+
"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
|
|
231
|
+
f"--no-warn-script-location --index-url "
|
|
232
|
+
f"{constants.DEFAULT_PIP_INDEX_URLS[0]} torch numpy\n"
|
|
233
|
+
),
|
|
234
|
+
),
|
|
235
|
+
],
|
|
236
|
+
)
|
|
237
|
+
def test_get_script_for_python_packages(test_case):
    """Test get_script_for_python_packages with various configurations.

    Only ``packages_to_install`` and ``pip_index_urls`` from each case's
    config are forwarded to the helper. The ``is_mpi`` key present in every
    case is never read here.
    NOTE(review): the "with MPI" case therefore exercises the same call as
    the non-MPI one. Confirm whether ``is_mpi`` should be dropped from the
    cases or wired through.
    """
    cfg = test_case.config
    generated = utils.get_script_for_python_packages(
        packages_to_install=cfg["packages_to_install"],
        pip_index_urls=cfg["pip_index_urls"],
    )
    assert generated == test_case.expected_output
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="with args dict always unpacks kwargs",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": {"batch_size": 128, "learning_rate": 0.001, "epochs": 20},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>(**{'batch_size': 128, 'learning_rate': 0.001, 'epochs': 20})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="without args calls function with no params",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": None,
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>()\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="raises when runtime has no trainer",
            expected_status=FAILED,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": None,
                "runtime": types.Runtime(name="no-trainer", trainer=None),
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="raises when train_func is not callable",
            expected_status=FAILED,
            config={
                "func": "not callable",
                "func_args": None,
                "runtime": _build_runtime(),
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="single dict param also unpacks kwargs",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": {"a": 1, "b": 2},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>(**{'a': 1, 'b': 2})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="multi-param function uses kwargs-unpacking",
            expected_status=SUCCESS,
            config={
                "func": (lambda **kwargs: "ok"),
                "func_args": {"a": 3, "b": "hi", "c": 0.2},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda **kwargs: "ok"),\n\n'
                    "<lambda>(**{'a': 3, 'b': 'hi', 'c': 0.2})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
    ],
)
def test_get_command_using_train_func(test_case: TestCase):
    """Check ``utils.get_command_using_train_func`` for each parametrized case.

    Each case's ``func`` lambda must stay on a single source line. The
    expected script embeds that exact line, which the helper presumably
    extracts via source inspection (TODO confirm against utils).

    SUCCESS cases compare the returned command list with ``expected_output``.
    FAILED cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)

    def _build_command():
        # Same arguments for every case; only runtime/func/func_args vary.
        return utils.get_command_using_train_func(
            runtime=test_case.config["runtime"],
            train_func=test_case.config.get("func"),
            train_func_parameters=test_case.config.get("func_args"),
            pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
            packages_to_install=[],
        )

    # Explicit branching replaces ``try/except Exception``. The old except
    # branch swallowed the SUCCESS-path AssertionError and, unlike sibling
    # tests, never checked that the case was expected to fail.
    if test_case.expected_status == FAILED:
        with pytest.raises(test_case.expected_error) as exc_info:
            _build_command()
        # pytest.raises also matches subclasses; keep the strict exact-type check.
        assert exc_info.type is test_case.expected_error
    else:
        assert test_case.expected_status == SUCCESS
        command = _build_command()
        assert command == test_case.expected_output
    print("test execution complete")
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="DataCacheInitializer with all optional fields",
            expected_status=SUCCESS,
            config={
                "initializer": types.DataCacheInitializer(
                    storage_uri="cache://test_schema/test_table",
                    num_data_nodes=3,
                    metadata_loc="s3://bucket/metadata",
                    head_cpu="1",
                    head_mem="1Gi",
                    worker_cpu="2",
                    worker_mem="2Gi",
                    iam_role="arn:aws:iam::123456789012:role/test-role",
                ),
            },
            expected_output={
                "storage_uri": "cache://test_schema/test_table",
                "env": {
                    # num_data_nodes + 1 (3 -> "4"); presumably counts a head node.
                    "CLUSTER_SIZE": "4",
                    "METADATA_LOC": "s3://bucket/metadata",
                    "HEAD_CPU": "1",
                    "HEAD_MEM": "1Gi",
                    "WORKER_CPU": "2",
                    "WORKER_MEM": "2Gi",
                    "IAM_ROLE": "arn:aws:iam::123456789012:role/test-role",
                },
            },
        ),
        TestCase(
            name="DataCacheInitializer with only required fields",
            expected_status=SUCCESS,
            config={
                "initializer": types.DataCacheInitializer(
                    storage_uri="cache://schema/table",
                    num_data_nodes=2,
                    metadata_loc="s3://bucket/metadata.json",
                ),
            },
            expected_output={
                "storage_uri": "cache://schema/table",
                "env": {
                    "CLUSTER_SIZE": "3",
                    "METADATA_LOC": "s3://bucket/metadata.json",
                },
            },
        ),
        TestCase(
            name="HuggingFaceDatasetInitializer without access token",
            expected_status=SUCCESS,
            config={
                "initializer": types.HuggingFaceDatasetInitializer(
                    storage_uri="hf://datasets/public-dataset",
                ),
            },
            expected_output={
                "storage_uri": "hf://datasets/public-dataset",
                "env": {},
            },
        ),
        TestCase(
            name="S3DatasetInitializer with all optional fields",
            expected_status=SUCCESS,
            config={
                "initializer": types.S3DatasetInitializer(
                    storage_uri="s3://my-bucket/datasets/train",
                    endpoint="https://s3.custom.com",
                    access_key_id="test-access-key",
                    secret_access_key="test-secret-key",
                    region="us-west-2",
                    role_arn="arn:aws:iam::123456789012:role/test-role",
                ),
            },
            expected_output={
                "storage_uri": "s3://my-bucket/datasets/train",
                "env": {
                    "ENDPOINT": "https://s3.custom.com",
                    "ACCESS_KEY_ID": "test-access-key",
                    "SECRET_ACCESS_KEY": "test-secret-key",
                    "REGION": "us-west-2",
                    "ROLE_ARN": "arn:aws:iam::123456789012:role/test-role",
                },
            },
        ),
        TestCase(
            name="Invalid dataset type",
            expected_status=FAILED,
            config={
                "initializer": "invalid_type",
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_dataset_initializer(test_case: TestCase):
    """Test get_dataset_initializer with various dataset initializer types.

    SUCCESS cases check ``storage_uri`` and that the initializer's env vars,
    flattened to ``{name: value}``, equal ``expected_output["env"]``.
    FAILED cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)
    initializer = test_case.config["initializer"]

    # Explicit branching instead of ``try/except Exception``, which swallowed
    # this test's own AssertionErrors and never checked the FAILED status.
    if test_case.expected_status == FAILED:
        with pytest.raises(test_case.expected_error) as exc_info:
            utils.get_dataset_initializer(initializer)
        # pytest.raises also matches subclasses; keep the strict exact-type check.
        assert exc_info.type is test_case.expected_error
    else:
        assert test_case.expected_status == SUCCESS
        dataset_initializer = utils.get_dataset_initializer(initializer)
        assert dataset_initializer is not None
        assert dataset_initializer.storage_uri == test_case.expected_output["storage_uri"]

        expected_env = test_case.expected_output.get("env", {})
        # ``or []`` also covers an ``env`` attribute that exists but is None;
        # getattr's default alone only applies when the attribute is missing.
        env_vars = getattr(dataset_initializer, "env", None) or []
        env_dict = {env_var.name: env_var.value for env_var in env_vars}
        assert env_dict == expected_env, f"Expected env {expected_env}, got {env_dict}"
    print("test execution complete")
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="HuggingFaceModelInitializer with access token and ignore patterns",
            expected_status=SUCCESS,
            config={
                "initializer": types.HuggingFaceModelInitializer(
                    storage_uri="hf://username/my-model",
                    access_token="hf_test_token_789",
                    ignore_patterns=["*.bin", "*.safetensors"],
                ),
            },
            expected_output={
                "storage_uri": "hf://username/my-model",
                "env": {
                    "ACCESS_TOKEN": "hf_test_token_789",
                    "IGNORE_PATTERNS": "*.bin,*.safetensors",
                },
            },
        ),
        TestCase(
            name="HuggingFaceModelInitializer without access token",
            expected_status=SUCCESS,
            config={
                "initializer": types.HuggingFaceModelInitializer(
                    storage_uri="hf://username/public-model",
                ),
            },
            expected_output={
                "storage_uri": "hf://username/public-model",
                "env": {
                    "IGNORE_PATTERNS": ",".join(constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS),
                },
            },
        ),
        TestCase(
            name="S3ModelInitializer with all optional fields",
            expected_status=SUCCESS,
            config={
                "initializer": types.S3ModelInitializer(
                    storage_uri="s3://my-bucket/models/trained-model",
                    endpoint="https://s3.custom.com",
                    access_key_id="test-access-key",
                    secret_access_key="test-secret-key",
                    region="us-east-1",
                    role_arn="arn:aws:iam::123456789012:role/test-role",
                    ignore_patterns=["*.txt", "*.log"],
                ),
            },
            expected_output={
                "storage_uri": "s3://my-bucket/models/trained-model",
                "env": {
                    "ENDPOINT": "https://s3.custom.com",
                    "ACCESS_KEY_ID": "test-access-key",
                    "SECRET_ACCESS_KEY": "test-secret-key",
                    "REGION": "us-east-1",
                    "ROLE_ARN": "arn:aws:iam::123456789012:role/test-role",
                    "IGNORE_PATTERNS": "*.txt,*.log",
                },
            },
        ),
        TestCase(
            name="Invalid model type",
            expected_status=FAILED,
            config={
                "initializer": "invalid_type",
            },
            expected_error=ValueError,
        ),
    ],
)
def test_get_model_initializer(test_case: TestCase):
    """Test get_model_initializer with various model initializer types.

    SUCCESS cases check ``storage_uri`` and that the initializer's env vars,
    flattened to ``{name: value}``, equal ``expected_output["env"]``.
    FAILED cases require the call to raise exactly ``expected_error``.
    """
    print("Executing test:", test_case.name)
    initializer = test_case.config["initializer"]

    # Explicit branching instead of ``try/except Exception``, which swallowed
    # this test's own AssertionErrors and never checked the FAILED status.
    if test_case.expected_status == FAILED:
        with pytest.raises(test_case.expected_error) as exc_info:
            utils.get_model_initializer(initializer)
        # pytest.raises also matches subclasses; keep the strict exact-type check.
        assert exc_info.type is test_case.expected_error
    else:
        assert test_case.expected_status == SUCCESS
        model_initializer = utils.get_model_initializer(initializer)
        assert model_initializer is not None
        assert model_initializer.storage_uri == test_case.expected_output["storage_uri"]

        expected_env = test_case.expected_output.get("env", {})
        # ``or []`` also covers an ``env`` attribute that exists but is None;
        # getattr's default alone only applies when the attribute is missing.
        env_vars = getattr(model_initializer, "env", None) or []
        env_dict = {env_var.name: env_var.value for env_var in env_vars}
        assert env_dict == expected_env, f"Expected env {expected_env}, got {env_dict}"
    print("test execution complete")
|
|
File without changes
|