viettelcloud-aiplatform 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Kubernetes backend-specific constants."""
@@ -0,0 +1,636 @@
1
+ # Copyright 2024 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import fields
16
+ import inspect
17
+ import os
18
+ import textwrap
19
+ from typing import Any, Callable, Optional, Union
20
+ from urllib.parse import urlparse
21
+
22
+ from kubeflow_trainer_api import models
23
+
24
+ from viettelcloud.aiplatform.trainer.constants import constants
25
+ from viettelcloud.aiplatform.trainer.types import types
26
+
27
+
28
def get_container_devices(
    resources: Optional[models.IoK8sApiCoreV1ResourceRequirements],
) -> Optional[tuple[str, str]]:
    """
    Return the (device type, device count) pair for the given container.

    Returns None when the container has no resource limits. Raises when the
    device type cannot be recognized or its count cannot be extracted.
    """
    # Without resource limits the device type is unknown.
    if resources is None or resources.limits is None:
        return None

    limits = resources.limits

    # TODO (andreyvelich): We should discuss how to get container device type.
    # Potentially, we can use the trainer.kubeflow.org/device label from the runtime or
    # node types.
    # TODO (andreyvelich): Support other resource labels (e.g. NPUs).
    if constants.GPU_LABEL in limits:
        device = constants.GPU_LABEL.split("/")[1]
        device_count = limits[constants.GPU_LABEL].actual_instance
    elif constants.TPU_LABEL in limits:
        device = constants.TPU_LABEL.split("/")[1]
        device_count = limits[constants.TPU_LABEL].actual_instance
    elif any(key.startswith(constants.GPU_MIG_PREFIX) for key in limits):
        mig_keys = [key for key in limits if key.startswith(constants.GPU_MIG_PREFIX)]
        if len(mig_keys) > 1:
            raise ValueError(f"Multiple MIG resource types are not supported yet: {mig_keys}")
        mig_key = mig_keys[0]
        device = mig_key.split("/")[1]
        device_count = limits[mig_key].actual_instance
    elif constants.CPU_LABEL in limits:
        # CPU keeps the whole label as the device name (no vendor prefix to strip).
        device = constants.CPU_LABEL
        device_count = limits[constants.CPU_LABEL].actual_instance
    else:
        raise Exception(f"Unknown device type in the container resources: {limits}")

    if device_count is None:
        raise Exception(f"Failed to get device count for resources: {limits}")

    return device, str(device_count)
65
+
66
+
67
def get_runtime_trainer_container(
    replicated_jobs: list[models.JobsetV1alpha2ReplicatedJob],
) -> Optional[models.IoK8sApiCoreV1Container]:
    """
    Locate the trainer node container within the given replicated jobs.

    Returns None when no ReplicatedJob carrying the ancestor label has a
    container named after the trainer node.
    """
    for rjob in replicated_jobs:
        if not (rjob.template.spec and rjob.template.spec.template.spec):
            raise Exception(f"Invalid ReplicatedJob template: {rjob}")

        # Only the ReplicatedJob labeled with the ancestor label hosts the Trainer.
        metadata = rjob.template.metadata
        has_ancestor_label = (
            metadata
            and metadata.labels
            and constants.TRAINJOB_ANCESTOR_LABEL in metadata.labels
        )
        if not has_ancestor_label:
            continue

        for container in rjob.template.spec.template.spec.containers:
            if container.name == constants.NODE:
                return container

    return None
90
+
91
+
92
def get_runtime_trainer(
    framework: str,
    replicated_jobs: list[models.JobsetV1alpha2ReplicatedJob],
    ml_policy: models.TrainerV1alpha1MLPolicy,
) -> types.RuntimeTrainer:
    """
    Build the RuntimeTrainer object from the runtime's replicated jobs and ML policy.
    """
    trainer_container = get_runtime_trainer_container(replicated_jobs)
    if not (trainer_container and trainer_container.image):
        raise Exception(f"Runtime doesn't have trainer container {replicated_jobs}")

    # Only the torchtune framework maps to the builtin trainer type here.
    is_torch_tune = framework == types.TORCH_TUNE
    trainer = types.RuntimeTrainer(
        trainer_type=(
            types.TrainerType.BUILTIN_TRAINER
            if is_torch_tune
            else types.TrainerType.CUSTOM_TRAINER
        ),
        framework=framework,
        image=trainer_container.image,
    )

    # Derive the device type/count from the container resource limits.
    devices = get_container_devices(trainer_container.resources)
    if devices:
        trainer.device, trainer.device_count = devices

    # Torch and MPI plugins override accelerator count.
    if ml_policy.torch and ml_policy.torch.num_proc_per_node:
        num_proc = ml_policy.torch.num_proc_per_node.actual_instance
        if isinstance(num_proc, int):
            trainer.device_count = str(num_proc)
    elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
        trainer.device_count = str(ml_policy.mpi.num_proc_per_node)

    # Multiply accelerator_count by the number of nodes (only for numeric counts).
    if trainer.device_count.isdigit() and ml_policy.num_nodes:
        trainer.device_count = str(int(trainer.device_count) * ml_policy.num_nodes)

    # Add number of training nodes.
    if ml_policy.num_nodes:
        trainer.num_nodes = ml_policy.num_nodes

    # Pick the Trainer entrypoint matching the framework / ML policy.
    if is_torch_tune:
        trainer.set_command(constants.TORCH_TUNE_COMMAND)
    elif ml_policy.torch:
        trainer.set_command(constants.TORCH_COMMAND)
    elif ml_policy.mpi:
        trainer.set_command(constants.MPI_COMMAND)
    else:
        trainer.set_command(constants.DEFAULT_COMMAND)

    return trainer
147
+
148
+
149
def get_trainjob_initializer_step(
    pod_name: str,
    pod_spec: models.IoK8sApiCoreV1PodSpec,
    pod_status: Optional[models.IoK8sApiCoreV1PodStatus],
) -> types.Step:
    """
    Get the TrainJob initializer step from the given Pod name, spec, and status.

    Args:
        pod_name: Name of the initializer Pod.
        pod_spec: Pod spec containing the initializer container.
        pod_status: Optional Pod status; its phase becomes the step status.

    Raises:
        ValueError: If the Pod has no dataset/model initializer container.
    """
    # Use a default instead of a bare next() so a Pod without an initializer
    # container raises a clear ValueError rather than an opaque StopIteration.
    container = next(
        (
            c
            for c in pod_spec.containers
            if c.name in {constants.DATASET_INITIALIZER, constants.MODEL_INITIALIZER}
        ),
        None,
    )
    if container is None:
        raise ValueError(
            f"Pod {pod_name} does not have an initializer container: {pod_spec.containers}"
        )

    step = types.Step(
        name=container.name,
        status=pod_status.phase if pod_status else None,
        pod_name=pod_name,
    )

    # Attach device info when resource limits are present.
    if devices := get_container_devices(container.resources):
        step.device, step.device_count = devices

    return step
174
+
175
+
176
def get_trainjob_node_step(
    pod_name: str,
    pod_spec: models.IoK8sApiCoreV1PodSpec,
    pod_status: Optional[models.IoK8sApiCoreV1PodStatus],
    trainjob_runtime: types.Runtime,
    replicated_job_name: str,
    job_index: int,
) -> types.Step:
    """
    Get the TrainJob trainer node step from the given Pod name, spec, and status.
    """
    container = next(c for c in pod_spec.containers if c.name == constants.NODE)

    step = types.Step(
        name=f"{constants.NODE}-{job_index}",
        status=pod_status.phase if pod_status else None,
        pod_name=pod_name,
    )

    devices = get_container_devices(container.resources)
    if devices:
        step.device, step.device_count = devices

    # For the MPI use-cases, the launcher container is always node-0
    # Thus, we should increase the index for other nodes.
    is_mpi = trainjob_runtime.trainer.command[0] == "mpirun"
    if is_mpi and replicated_job_name != constants.LAUNCHER:
        # TODO (andreyvelich): We should also override the device_count
        # based on OMPI_MCA_orte_set_default_slots value. Right now, it is hard to do
        # since we inject this env only to the Launcher Pod.
        step.name = f"{constants.NODE}-{job_index + 1}"

    # A numeric torch num-proc env var overrides the device count.
    for env in container.env or []:
        if (
            env.name == constants.TORCH_ENV_NUM_PROC_PER_NODE
            and env.value
            and env.value.isdigit()
        ):
            step.device_count = env.value

    return step
220
+
221
+
222
# TODO (andreyvelich): Discuss if we want to support V1ResourceRequirements resources as input.
def get_resources_per_node(
    resources_per_node: dict,
) -> models.IoK8sApiCoreV1ResourceRequirements:
    """
    Build the Trainer node resource requirements from a plain dict.

    Keys are lowercased; "gpu" is an alias for "nvidia.com/gpu" and
    "mig-<profile>" is an alias for "nvidia.com/mig-<profile>".
    """
    # Lowercase every key and wrap each value in a Quantity object.
    quantities = {
        key.lower(): models.IoK8sApimachineryPkgApiResourceQuantity(value)
        for key, value in resources_per_node.items()
    }

    # "gpu" is shorthand for the NVIDIA GPU resource name.
    if "gpu" in quantities:
        quantities["nvidia.com/gpu"] = quantities.pop("gpu")

    # Optional alias for MIG: "mig-<profile>" -> "nvidia.com/mig-<profile>"
    # Example: "mig-1g.5gb" -> "nvidia.com/mig-1g.5gb"
    for alias in [key for key in quantities if key.startswith("mig-")]:
        quantities[f"{constants.GPU_MIG_PREFIX}{alias[len('mig-') :]}"] = quantities.pop(alias)

    # Validate MIG usage: at most one MIG profile, and never together with full GPUs.
    mig_keys = [key for key in quantities if key.startswith(constants.GPU_MIG_PREFIX)]
    if len(mig_keys) > 1:
        raise ValueError(f"Multiple MIG resource types are not supported: {mig_keys}")
    if mig_keys and "nvidia.com/gpu" in quantities:
        raise ValueError(
            f"GPU (nvidia.com/gpu) and MIG ({mig_keys[0]}) cannot be requested together"
        )

    return models.IoK8sApiCoreV1ResourceRequirements(
        requests=quantities,
        limits=quantities,
    )
257
+
258
+
259
def get_script_for_python_packages(
    packages_to_install: list[str],
    pip_index_urls: list[str],
) -> str:
    """
    Build a shell snippet that installs Python packages from the given pip index URLs.

    The first URL becomes --index-url; the rest become --extra-index-url options.
    The install is attempted with --user first, then retried without it.
    """
    pkgs = " ".join(packages_to_install)

    # First URL is the primary index; the remainder are extra indexes.
    index_options = " ".join(
        [f"--index-url {pip_index_urls[0]}"]
        + [f"--extra-index-url {url}" for url in pip_index_urls[1:]]
    )

    # Bootstrap pip if it is not already available in the image.
    header = textwrap.dedent(
        """
        if ! [ -x "$(command -v pip)" ]; then
            python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
        fi

        """
    )

    return (
        f"{header}"
        f"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
        f"--no-warn-script-location {index_options} --user {pkgs}"
        " ||\n"
        f"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
        f"--no-warn-script-location {index_options} {pkgs}\n"
    )
296
+
297
+
298
def get_command_using_train_func(
    runtime: types.Runtime,
    train_func: Callable,
    train_func_parameters: Optional[dict[str, Any]],
    pip_index_urls: list[str],
    packages_to_install: Optional[list[str]],
) -> list[str]:
    """
    Get the Trainer container command from the given training function and parameters.

    Args:
        runtime: Runtime whose trainer command template is filled in. The
            template entries may contain a "{func_file}" placeholder.
        train_func: Training function whose source is embedded in the command.
        train_func_parameters: Optional kwargs unpacked into the function call.
        pip_index_urls: Pip index URLs; the first is the primary index, the
            rest become extra indexes.
        packages_to_install: Optional packages to install before execution.

    Returns:
        The trainer command list with the serialized function code injected.

    Raises:
        ValueError: If the runtime has no trainer or train_func is not callable.
    """
    # Check if the runtime has a Trainer.
    if not runtime.trainer:
        raise ValueError(f"Runtime must have a trainer: {runtime}")

    # Check if training function is callable.
    if not callable(train_func):
        raise ValueError(
            f"Training function must be callable, got function type: {type(train_func)}"
        )

    # Extract the function implementation.
    func_code = inspect.getsource(train_func)

    # Extract the file name where the function is defined.
    func_file = os.path.basename(inspect.getfile(train_func))

    # Function might be defined in some indented scope (e.g. in another function).
    # We need to dedent the function code.
    func_code = textwrap.dedent(func_code)

    # Wrap function code to execute it from the file. For example:
    # TODO (andreyvelich): Find a better way to run users' scripts.
    # def train(parameters):
    #     print('Start Training...')
    # train({'lr': 0.01})
    if train_func_parameters is None:
        func_call = f"{train_func.__name__}()"
    else:
        # Always unpack kwargs for training function calls.
        func_call = f"{train_func.__name__}(**{train_func_parameters})"

    # Combine everything into the final code string.
    func_code = f"{func_code}\n{func_call}\n"

    is_mpi = runtime.trainer.command[0] == "mpirun"
    # The default file location for OpenMPI is: /home/mpiuser/<FILE_NAME>.py
    if is_mpi:
        func_file = os.path.join(constants.DEFAULT_MPI_USER_HOME, func_file)

    # Install Python packages if that is required.
    install_packages = ""
    if packages_to_install:
        install_packages = get_script_for_python_packages(
            packages_to_install,
            pip_index_urls,
        )

    # Add function code to the Trainer command.
    # Each command entry containing "{func_file}" is treated as a template and
    # filled via str.format; the install script (if any) is prepended to it.
    command = []
    for c in runtime.trainer.command:
        if "{func_file}" in c:
            exec_script = c.format(func_code=func_code, func_file=func_file)
            if install_packages:
                exec_script = install_packages + exec_script
            command.append(exec_script)
        else:
            command.append(c)

    return command
367
+
368
+
369
def get_trainer_cr_from_custom_trainer(
    runtime: types.Runtime,
    trainer: Union[types.CustomTrainer, types.CustomTrainerContainer],
) -> models.TrainerV1alpha1Trainer:
    """
    Build the Trainer CR from the custom trainer configuration.

    Args:
        runtime: The runtime configuration.
        trainer: The custom trainer or container configuration.
    """
    trainer_cr = models.TrainerV1alpha1Trainer()

    # Number of training nodes.
    if trainer.num_nodes:
        trainer_cr.num_nodes = trainer.num_nodes

    # Per-node resource requirements.
    if trainer.resources_per_node:
        trainer_cr.resources_per_node = get_resources_per_node(trainer.resources_per_node)

    # A CustomTrainer carries a Python function; serialize it into the command.
    # A CustomTrainerContainer keeps the runtime's own command.
    if isinstance(trainer, types.CustomTrainer):
        trainer_cr.command = get_command_using_train_func(
            runtime,
            trainer.func,
            trainer.func_args,
            trainer.pip_index_urls,
            trainer.packages_to_install,
        )

    # Optional image override.
    if trainer.image:
        trainer_cr.image = trainer.image

    # Environment variables for the trainer container.
    if trainer.env:
        trainer_cr.env = [
            models.IoK8sApiCoreV1EnvVar(name=name, value=value)
            for name, value in trainer.env.items()
        ]

    return trainer_cr
411
+
412
+
413
def get_trainer_cr_from_builtin_trainer(
    runtime: types.Runtime,
    trainer: types.BuiltinTrainer,
    initializer: Optional[types.Initializer] = None,
) -> models.TrainerV1alpha1Trainer:
    """
    Build the Trainer CR from the builtin trainer configuration.
    """
    config = trainer.config
    if not isinstance(config, types.TorchTuneConfig):
        raise ValueError(f"The BuiltinTrainer config is invalid: {config}")

    trainer_cr = models.TrainerV1alpha1Trainer()

    # Number of training nodes.
    if config.num_nodes:
        trainer_cr.num_nodes = config.num_nodes

    # Per-node resource requirements.
    if config.resources_per_node:
        trainer_cr.resources_per_node = get_resources_per_node(config.resources_per_node)

    # Keep the runtime's entrypoint; the TorchTuneConfig is translated into
    # args, preparing for the mutation of the torchtune config in the runtime
    # plugin.
    # Ref:https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2
    trainer_cr.command = list(runtime.trainer.command)
    trainer_cr.args = get_args_using_torchtune_config(config, initializer)

    return trainer_cr
441
+
442
+
443
def get_args_using_torchtune_config(
    fine_tuning_config: types.TorchTuneConfig,
    initializer: Optional[types.Initializer] = None,
) -> list[str]:
    """
    Build the torchtune override args from the TorchTuneConfig.
    """
    config = fine_tuning_config
    args: list[str] = []

    # Override the dtype if it is provided.
    if config.dtype:
        if not isinstance(config.dtype, types.DataType):
            raise ValueError(f"Invalid dtype: {config.dtype}.")
        args.append(f"dtype={config.dtype}")

    # Simple scalar overrides, emitted only when set.
    if config.batch_size:
        args.append(f"batch_size={config.batch_size}")
    if config.epochs:
        args.append(f"epochs={config.epochs}")
    if config.loss:
        args.append(f"loss={config.loss}")

    # Override the data dir or data files when a HuggingFace dataset is used.
    if isinstance(initializer, types.Initializer) and isinstance(
        initializer.dataset, types.HuggingFaceDatasetInitializer
    ):
        uri = initializer.dataset.storage_uri
        if not uri.startswith("hf://"):
            uri = "hf://" + uri
        path_parts = urlparse(uri).path.strip("/").split("/")
        relative_path = "/".join(path_parts[1:]) if len(path_parts) > 1 else "."

        # A "." in the path indicates a concrete file; otherwise treat it as a directory.
        if relative_path != "." and "." in relative_path:
            args.append(f"dataset.data_files={os.path.join(constants.DATASET_PATH, relative_path)}")
        else:
            args.append(f"dataset.data_dir={os.path.join(constants.DATASET_PATH, relative_path)}")

    if config.peft_config:
        args += get_args_from_peft_config(config.peft_config)

    if config.dataset_preprocess_config:
        args += get_args_from_dataset_preprocess_config(config.dataset_preprocess_config)

    return args
498
+
499
+
500
def get_args_from_peft_config(peft_config: types.LoraConfig) -> list[str]:
    """
    Build torchtune override args from the given PEFT (LoRA) config.
    """
    if not isinstance(peft_config, types.LoraConfig):
        raise ValueError(f"Invalid PEFT config type: {type(peft_config)}.")

    # Mapping from LoraConfig attribute names to torchtune arg names.
    field_map = {
        "apply_lora_to_mlp": "model.apply_lora_to_mlp",
        "apply_lora_to_output": "model.apply_lora_to_output",
        "lora_rank": "model.lora_rank",
        "lora_alpha": "model.lora_alpha",
        "lora_dropout": "model.lora_dropout",
        "quantize_base": "model.quantize_base",
        "use_dora": "model.use_dora",
    }

    # Emit only truthy fields.
    # NOTE(review): falsy values (False, 0, 0.0) are treated as "not provided"
    # and skipped — presumably intentional, but worth confirming.
    args = [
        f"{arg_name}={value}"
        for field, arg_name in field_map.items()
        if (value := getattr(peft_config, field, None))
    ]

    # Override the LoRA attention modules if they are provided.
    if peft_config.lora_attn_modules:
        args.append(f"model.lora_attn_modules=[{','.join(peft_config.lora_attn_modules)}]")

    return args
530
+
531
+
532
def get_args_from_dataset_preprocess_config(
    dataset_preprocess_config: types.TorchTuneInstructDataset,
) -> list[str]:
    """
    Build torchtune override args from the given dataset preprocess config.
    """
    config = dataset_preprocess_config
    if not isinstance(config, types.TorchTuneInstructDataset):
        raise ValueError(
            f"Invalid dataset preprocess config type: {type(dataset_preprocess_config)}."
        )

    # The dataset type is always overridden in the torchtune config.
    args = [f"dataset={constants.TORCH_TUNE_INSTRUCT_DATASET}"]

    # Override the dataset source field if it is provided.
    if config.source:
        if not isinstance(config.source, types.DataFormat):
            raise ValueError(f"Invalid data format: {config.source.value}.")
        args.append(f"dataset.source={config.source.value}")

    # Remaining fields are emitted only when set.
    if config.split:
        args.append(f"dataset.split={config.split}")
    if config.train_on_input:
        args.append(f"dataset.train_on_input={config.train_on_input}")
    if config.new_system_prompt:
        args.append(f"dataset.new_system_prompt={config.new_system_prompt}")
    if config.column_map:
        args.append(f"dataset.column_map={config.column_map}")

    return args
572
+
573
+
574
def get_optional_initializer_envs(
    initializer: types.BaseInitializer, required_fields: set
) -> list[models.IoK8sApiCoreV1EnvVar]:
    """Build env vars from the initializer's optional (non-required) fields.

    Field names are uppercased to form the env var names; None values are
    skipped and list values are joined into comma-separated strings.
    """
    envs: list[models.IoK8sApiCoreV1EnvVar] = []
    for f in fields(initializer):
        if f.name in required_fields:
            continue
        value = getattr(initializer, f.name)
        if value is None:
            continue
        # Convert list values (like ignore_patterns) to comma-separated strings
        if isinstance(value, list):
            value = ",".join(str(item) for item in value)
        envs.append(models.IoK8sApiCoreV1EnvVar(name=f.name.upper(), value=value))
    return envs
588
+
589
+
590
def get_dataset_initializer(
    dataset: Union[
        types.HuggingFaceDatasetInitializer,
        types.S3DatasetInitializer,
        types.DataCacheInitializer,
    ],
) -> models.TrainerV1alpha1DatasetInitializer:
    """
    Build the TrainJob dataset initializer from the given config.
    """
    # HuggingFace/S3 initializers: storage URI plus optional fields as envs.
    if isinstance(dataset, (types.HuggingFaceDatasetInitializer, types.S3DatasetInitializer)):
        return models.TrainerV1alpha1DatasetInitializer(
            storageUri=dataset.storage_uri,
            env=get_optional_initializer_envs(dataset, required_fields={"storage_uri"}),
        )

    if isinstance(dataset, types.DataCacheInitializer):
        # CLUSTER_SIZE is num_data_nodes + 1 (presumably including a
        # coordinator node — confirm against the cache initializer).
        envs = [
            models.IoK8sApiCoreV1EnvVar(
                name="CLUSTER_SIZE", value=str(dataset.num_data_nodes + 1)
            ),
            models.IoK8sApiCoreV1EnvVar(name="METADATA_LOC", value=dataset.metadata_loc),
        ]
        # Add env vars from optional fields (skip required fields)
        envs += get_optional_initializer_envs(
            dataset, {"storage_uri", "metadata_loc", "num_data_nodes"}
        )
        return models.TrainerV1alpha1DatasetInitializer(
            storageUri=dataset.storage_uri, env=envs if envs else None
        )

    raise ValueError(f"Dataset initializer type is invalid: {type(dataset)}")
622
+
623
+
624
def get_model_initializer(
    model: Union[types.HuggingFaceModelInitializer, types.S3ModelInitializer],
) -> models.TrainerV1alpha1ModelInitializer:
    """
    Build the TrainJob model initializer from the given config.
    """
    # Guard clause: reject any unsupported initializer type up front.
    if not isinstance(model, (types.HuggingFaceModelInitializer, types.S3ModelInitializer)):
        raise ValueError(f"Model initializer type is invalid: {type(model)}")

    return models.TrainerV1alpha1ModelInitializer(
        storageUri=model.storage_uri,
        env=get_optional_initializer_envs(model, required_fields={"storage_uri"}),
    )