viettelcloud-aiplatform 0.3.0 (viettelcloud_aiplatform-0.3.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
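Only two of the 71 files are reproduced below: the local-process backend test module (+501 lines) and its constants module (+90 lines). The test file exercises the trainer API end to end, so a minimal usage sketch can be inferred from its calls. The sketch here is assembled purely from names and signatures that appear in this diff (LocalProcessBackend, LocalProcessBackendConfig, CustomTrainer, train, wait_for_job_status, get_job_logs); runtime behavior, such as whether wait_for_job_status blocks, is assumed rather than documented here.

# Hedged sketch inferred from backend_test.py below; not official usage documentation.
from viettelcloud.aiplatform.trainer.backends.localprocess.backend import LocalProcessBackend
from viettelcloud.aiplatform.trainer.backends.localprocess.types import LocalProcessBackendConfig
from viettelcloud.aiplatform.trainer.constants import constants
from viettelcloud.aiplatform.trainer.types import types


def train_fn():
    # User-supplied training function that the local backend runs in a subprocess.
    print("training step")


backend = LocalProcessBackend(LocalProcessBackendConfig())
runtime = backend.get_runtime(constants.DEFAULT_TRAINING_RUNTIME)

# CustomTrainer bundles the function with optional pip packages and env vars,
# mirroring the test cases below.
job_name = backend.train(
    runtime=runtime,
    trainer=types.CustomTrainer(func=train_fn, packages_to_install=["torch"]),
    options=[],
)

backend.wait_for_job_status(job_name)  # assumed to block until the job finishes
for line in backend.get_job_logs(job_name, step="train"):
    print(line)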
viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py
@@ -0,0 +1,501 @@
+ # Copyright 2025 The Kubeflow Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Unit tests for the LocalProcessBackend class in the Kubeflow Trainer SDK.
+ """
+
+ from unittest.mock import Mock, patch
+
+ import pytest
+
+ from viettelcloud.aiplatform.trainer.backends.localprocess.backend import LocalProcessBackend
+ from viettelcloud.aiplatform.trainer.backends.localprocess.constants import LOCAL_RUNTIME_IMAGE
+ from viettelcloud.aiplatform.trainer.backends.localprocess.types import (
+     LocalProcessBackendConfig,
+     LocalRuntimeTrainer,
+ )
+ from viettelcloud.aiplatform.trainer.constants import constants
+ from viettelcloud.aiplatform.trainer.options import (
+     Annotations,
+     Labels,
+     Name,
+     PodTemplateOverride,
+     PodTemplateOverrides,
+ )
+ from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
+ from viettelcloud.aiplatform.trainer.types import types
+
+ # Test constants
+ TORCH_RUNTIME = constants.DEFAULT_TRAINING_RUNTIME
+ BASIC_TRAIN_JOB_NAME = "test-job"
+
+
+ def dummy_training_function():
+     """Dummy training function for testing."""
+     print("Training started")
+     return {"loss": 0.5, "accuracy": 0.95}
+
+
+ @pytest.fixture
+ def local_backend():
+     """Create LocalProcessBackend for testing."""
+     cfg = LocalProcessBackendConfig()
+     backend = LocalProcessBackend(cfg)
+     yield backend
+     # Cleanup: Clear jobs to prevent test pollution
+     backend._LocalProcessBackend__local_jobs.clear()
+
+
+ @pytest.fixture
+ def mock_train_environment():
+     """Mock the training environment to avoid actual subprocess execution."""
+     with (
+         patch(
+             "viettelcloud.aiplatform.trainer.backends.localprocess.job.LocalJob.start"
+         ) as mock_start,
+         patch(
+             "viettelcloud.aiplatform.trainer.backends.localprocess.utils.get_local_runtime_trainer"
+         ) as mock_get_trainer,
+         patch(
+             "viettelcloud.aiplatform.trainer.backends.localprocess.utils.get_local_train_job_script"
+         ) as mock_get_script,
+         patch("tempfile.mkdtemp") as mock_mkdtemp,
+     ):
+         # Setup mock return values
+         mock_mkdtemp.return_value = "/tmp/test-venv"
+         mock_get_script.return_value = ["/bin/bash", "-c", "echo 'training'"]
+
+         mock_trainer = LocalRuntimeTrainer(
+             trainer_type=types.TrainerType.CUSTOM_TRAINER,
+             framework="torch",
+             num_nodes=1,
+             device_count="1",
+             device="cpu",
+             packages=["torch"],
+             image=LOCAL_RUNTIME_IMAGE,
+         )
+         mock_trainer.set_command = Mock()
+         mock_get_trainer.return_value = mock_trainer
+
+         yield {
+             "start": mock_start,
+             "get_trainer": mock_get_trainer,
+             "get_script": mock_get_script,
+             "mkdtemp": mock_mkdtemp,
+         }
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="list_all_local_runtimes",
+             expected_status=SUCCESS,
+             config={},
+         ),
+     ],
+ )
+ def test_list_runtimes(local_backend, test_case):
+     """Test LocalProcessBackend.list_runtimes()."""
+     runtimes = local_backend.list_runtimes()
+     assert len(runtimes) > 0
+     assert all(isinstance(rt, types.Runtime) for rt in runtimes)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="get_existing_runtime",
+             expected_status=SUCCESS,
+             config={"runtime_name": TORCH_RUNTIME},
+         ),
+         TestCase(
+             name="get_nonexistent_runtime",
+             expected_status=FAILED,
+             config={"runtime_name": "nonexistent-runtime"},
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_get_runtime(local_backend, test_case):
+     """Test LocalProcessBackend.get_runtime()."""
+     runtime_name = test_case.config.get("runtime_name")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             local_backend.get_runtime(runtime_name)
+     else:
+         runtime = local_backend.get_runtime(runtime_name)
+         assert runtime is not None
+         assert runtime.name == runtime_name
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="get_packages_for_existing_runtime",
+             expected_status=SUCCESS,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+             },
+         ),
+         TestCase(
+             name="get_packages_for_nonexistent_runtime",
+             expected_status=FAILED,
+             config={
+                 "runtime": types.Runtime(
+                     name="nonexistent-runtime",
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+             },
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_get_runtime_packages(local_backend, test_case):
+     """Test LocalProcessBackend.get_runtime_packages()."""
+     runtime = test_case.config.get("runtime")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             local_backend.get_runtime_packages(runtime)
+     else:
+         packages = local_backend.get_runtime_packages(runtime)
+         assert packages is not None
+         assert isinstance(packages, list)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="train with basic custom trainer - no options",
+             expected_status=SUCCESS,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                     packages_to_install=["numpy", "torch"],
+                 ),
+                 "options": [],
+             },
+         ),
+         TestCase(
+             name="train with custom trainer and environment variables",
+             expected_status=SUCCESS,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                     packages_to_install=["torch"],
+                     env={"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "4"},
+                 ),
+                 "options": [],
+             },
+         ),
+         TestCase(
+             name="train rejects kubernetes labels option",
+             expected_status=FAILED,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                 ),
+                 "options": [Labels({"app": "test"})],
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="train rejects kubernetes annotations option",
+             expected_status=FAILED,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                 ),
+                 "options": [Annotations({"description": "test"})],
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="train rejects pod template overrides option",
+             expected_status=FAILED,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                 ),
+                 "options": [
+                     PodTemplateOverrides(
+                         PodTemplateOverride(
+                             target_jobs=["node"],
+                         )
+                     )
+                 ],
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="train fails without runtime",
+             expected_status=FAILED,
+             config={
+                 "runtime": None,
+                 "trainer": types.CustomTrainer(
+                     func=dummy_training_function,
+                 ),
+                 "options": [],
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="train fails without custom trainer",
+             expected_status=FAILED,
+             config={
+                 "runtime": types.Runtime(
+                     name=TORCH_RUNTIME,
+                     trainer=types.RuntimeTrainer(
+                         trainer_type=types.TrainerType.CUSTOM_TRAINER,
+                         framework="torch",
+                         num_nodes=1,
+                         image=LOCAL_RUNTIME_IMAGE,
+                     ),
+                 ),
+                 "trainer": None,
+             },
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_train(local_backend, mock_train_environment, test_case):
+     """Test LocalProcessBackend.train() with success and failure cases."""
+     runtime = test_case.config.get("runtime")
+     trainer = test_case.config.get("trainer")
+     options = test_case.config.get("options", [])
+
+     mocks = mock_train_environment
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error) as exc_info:
+             local_backend.train(
+                 runtime=runtime,
+                 trainer=trainer,
+                 options=options,
+             )
+
+         # Verify specific error messages
+         error_msg = str(exc_info.value)
+         if "rejects kubernetes" in test_case.name:
+             assert "not compatible with" in error_msg
+         elif "without runtime" in test_case.name:
+             assert "Runtime must be provided" in error_msg
+         elif "without custom trainer" in test_case.name:
+             assert "CustomTrainer must be set" in error_msg
+     else:
+         train_job_name = local_backend.train(
+             runtime=runtime,
+             trainer=trainer,
+             options=options,
+         )
+
+         assert train_job_name is not None
+         assert len(train_job_name) > 0
+         mocks["start"].assert_called_once()
+         mocks["get_trainer"].assert_called_once()
+         mocks["get_script"].assert_called_once()
+
+         # Verify job is tracked
+         jobs = local_backend.list_jobs(runtime=runtime)
+         assert any(job.name == train_job_name for job in jobs)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="get_nonexistent_job",
+             expected_status=FAILED,
+             config={"job_name": "nonexistent-job"},
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_get_job(local_backend, test_case):
+     """Test LocalProcessBackend.get_job()."""
+     job_name = test_case.config.get("job_name")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             local_backend.get_job(job_name)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="list_jobs_empty",
+             expected_status=SUCCESS,
+             config={"runtime": None},
+         ),
+     ],
+ )
+ def test_list_jobs(local_backend, test_case):
+     """Test LocalProcessBackend.list_jobs()."""
+     runtime = test_case.config.get("runtime")
+     jobs = local_backend.list_jobs(runtime=runtime)
+     assert isinstance(jobs, list)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="get_logs_nonexistent_job",
+             expected_status=FAILED,
+             config={"job_name": "nonexistent-job", "step": "train"},
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_get_job_logs(local_backend, test_case):
+     """Test LocalProcessBackend.get_job_logs()."""
+     job_name = test_case.config.get("job_name")
+     step = test_case.config.get("step", "train")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             list(local_backend.get_job_logs(job_name, step=step))
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="wait_for_nonexistent_job",
+             expected_status=FAILED,
+             config={"job_name": "nonexistent-job"},
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_wait_for_job_status(local_backend, test_case):
+     """Test LocalProcessBackend.wait_for_job_status()."""
+     job_name = test_case.config.get("job_name")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             local_backend.wait_for_job_status(job_name)
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="delete_nonexistent_job",
+             expected_status=FAILED,
+             config={"job_name": "nonexistent-job"},
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_delete_job(local_backend, test_case):
+     """Test LocalProcessBackend.delete_job()."""
+     job_name = test_case.config.get("job_name")
+
+     if test_case.expected_status == FAILED:
+         with pytest.raises(test_case.expected_error):
+             local_backend.delete_job(job_name)
+
+
+ def test_name_option_sets_job_name(local_backend, mock_train_environment):
+     """Test that Name option sets the custom job name."""
+     custom_name = "my-custom-job-name"
+
+     def dummy_func():
+         pass
+
+     runtime = types.Runtime(
+         name=TORCH_RUNTIME,
+         trainer=types.RuntimeTrainer(
+             trainer_type=types.TrainerType.CUSTOM_TRAINER,
+             framework="torch",
+             image=LOCAL_RUNTIME_IMAGE,
+         ),
+     )
+
+     trainer = types.CustomTrainer(func=dummy_func)
+     options = [Name(name=custom_name)]
+
+     job_name = local_backend.train(
+         runtime=runtime,
+         trainer=trainer,
+         options=options,
+     )
+
+     assert job_name == custom_name
viettelcloud/aiplatform/trainer/backends/localprocess/constants.py
@@ -0,0 +1,90 @@
+ # Copyright 2025 The Kubeflow Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ import textwrap
+
+ import viettelcloud.aiplatform.common.constants as common_constants
+ from viettelcloud.aiplatform.trainer.backends.localprocess import types
+ from viettelcloud.aiplatform.trainer.constants import constants
+ from viettelcloud.aiplatform.trainer.types import types as base_types
+
+ TORCH_FRAMEWORK_TYPE = "torch"
+
+ # Image name for the local runtime.
+ LOCAL_RUNTIME_IMAGE = "local"
+
+ local_runtimes = [
+     base_types.Runtime(
+         name=constants.DEFAULT_TRAINING_RUNTIME,
+         trainer=types.LocalRuntimeTrainer(
+             trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+             framework=TORCH_FRAMEWORK_TYPE,
+             num_nodes=1,
+             device_count=common_constants.UNKNOWN,
+             device=common_constants.UNKNOWN,
+             packages=["torch"],
+             image=LOCAL_RUNTIME_IMAGE,
+         ),
+     )
+ ]
+
+
+ # Create venv script
+
+
+ # The exec script to embed training function into container command.
+ DEPENDENCIES_SCRIPT = textwrap.dedent(
+     """
+     PIP_DISABLE_PIP_VERSION_CHECK=1 pip install $QUIET \
+     --no-warn-script-location $PIP_INDEX $PACKAGE_STR
+     """
+ )
+
+ # activate virtualenv, then run the entrypoint from the virtualenv bin
+ LOCAL_EXEC_ENTRYPOINT = textwrap.dedent(
+     """
+     $ENTRYPOINT "$FUNC_FILE" "$PARAMETERS"
+     """
+ )
+
+ TORCH_COMMAND = "torchrun"
+
+ # default command, will run from within the virtualenv
+ DEFAULT_COMMAND = "python"
+
+ # remove virtualenv after training is completed.
+ LOCAL_EXEC_JOB_CLEANUP_SCRIPT = textwrap.dedent(
+     """
+     rm -rf $PYENV_LOCATION
+     """
+ )
+
+
+ LOCAL_EXEC_JOB_TEMPLATE = textwrap.dedent(
+     """
+     set -e
+     $OS_PYTHON_BIN -m venv --without-pip $PYENV_LOCATION
+     echo "Operating inside $PYENV_LOCATION"
+     source $PYENV_LOCATION/bin/activate
+     $PYENV_LOCATION/bin/python -m ensurepip --upgrade --default-pip
+     $DEPENDENCIES_SCRIPT
+     $ENTRYPOINT
+     $CLEANUP_SCRIPT
+     """
+ )
+
+ LOCAL_EXEC_FILENAME = "train_{}.py"
+
+ PYTHON_PACKAGE_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")
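The constants above describe the whole local execution flow as shell fragments with $PLACEHOLDER slots: LOCAL_EXEC_JOB_TEMPLATE creates a virtualenv, DEPENDENCIES_SCRIPT installs the declared packages into it, LOCAL_EXEC_ENTRYPOINT runs the generated training script, and LOCAL_EXEC_JOB_CLEANUP_SCRIPT removes the virtualenv. The assembly itself lives in localprocess/utils.py, which this section does not show, so the rendering below is only a sketch; the use of string.Template and every concrete value in it are assumptions for illustration.

# Hypothetical rendering of the job script; the real logic is in localprocess/utils.py (not shown).
from string import Template

from viettelcloud.aiplatform.trainer.backends.localprocess.constants import (
    DEFAULT_COMMAND,
    DEPENDENCIES_SCRIPT,
    LOCAL_EXEC_ENTRYPOINT,
    LOCAL_EXEC_FILENAME,
    LOCAL_EXEC_JOB_CLEANUP_SCRIPT,
    LOCAL_EXEC_JOB_TEMPLATE,
)

venv_dir = "/tmp/example-venv"  # made-up location; the tests above mock tempfile.mkdtemp for this

script = Template(LOCAL_EXEC_JOB_TEMPLATE).safe_substitute(
    OS_PYTHON_BIN="python3",
    PYENV_LOCATION=venv_dir,
    DEPENDENCIES_SCRIPT=Template(DEPENDENCIES_SCRIPT).safe_substitute(
        QUIET="--quiet", PIP_INDEX="", PACKAGE_STR="torch"
    ),
    ENTRYPOINT=Template(LOCAL_EXEC_ENTRYPOINT).safe_substitute(
        ENTRYPOINT=DEFAULT_COMMAND,
        FUNC_FILE=LOCAL_EXEC_FILENAME.format("job1"),
        PARAMETERS="{}",
    ),
    CLEANUP_SCRIPT=Template(LOCAL_EXEC_JOB_CLEANUP_SCRIPT).safe_substitute(
        PYENV_LOCATION=venv_dir
    ),
)
# `script` is now a self-contained bash snippet: create the venv, install packages,
# run the training entrypoint, then delete the venv.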