viettelcloud-aiplatform 0.3.0 (viettelcloud_aiplatform-0.3.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py
@@ -0,0 +1,710 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable, Iterator
+import copy
+import logging
+import multiprocessing
+import random
+import re
+import string
+import time
+from typing import Any, Optional, Union
+import uuid
+
+from kubeflow_trainer_api import models
+from kubernetes import client, config, watch
+
+import viettelcloud.aiplatform.common.constants as common_constants
+from viettelcloud.aiplatform.common.types import KubernetesBackendConfig
+import viettelcloud.aiplatform.common.utils as common_utils
+from viettelcloud.aiplatform.trainer.backends.base import RuntimeBackend
+import viettelcloud.aiplatform.trainer.backends.kubernetes.utils as utils
+from viettelcloud.aiplatform.trainer.constants import constants
+from viettelcloud.aiplatform.trainer.types import types
+
+logger = logging.getLogger(__name__)
+
+
+class KubernetesBackend(RuntimeBackend):
+    def __init__(self, cfg: KubernetesBackendConfig):
+        if cfg.namespace is None:
+            cfg.namespace = common_utils.get_default_target_namespace(cfg.context)
+
+        # If client configuration is not set, use kube-config to access Kubernetes APIs.
+        if cfg.client_configuration is None:
+            # Load kube-config or in-cluster config.
+            if cfg.config_file or not common_utils.is_running_in_k8s():
+                config.load_kube_config(config_file=cfg.config_file, context=cfg.context)
+            else:
+                config.load_incluster_config()
+
+        k8s_client = client.ApiClient(cfg.client_configuration)
+        self.custom_api = client.CustomObjectsApi(k8s_client)
+        self.core_api = client.CoreV1Api(k8s_client)
+
+        self.namespace = cfg.namespace
+
+    def list_runtimes(self) -> list[types.Runtime]:
+        result = []
+        try:
+            thread = self.custom_api.list_cluster_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
+                async_req=True,
+            )
+
+            runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict(
+                thread.get(common_constants.DEFAULT_TIMEOUT)
+            )
+
+            if not runtime_list:
+                return result
+
+            for runtime in runtime_list.items:
+                if not (
+                    runtime.metadata
+                    and runtime.metadata.labels
+                    and constants.RUNTIME_FRAMEWORK_LABEL in runtime.metadata.labels
+                ):
+                    logger.warning(
+                        f"Runtime {runtime.metadata.name} must have "  # type: ignore
+                        f"{constants.RUNTIME_FRAMEWORK_LABEL} label."
+                    )
+                    continue
+                result.append(self.__get_runtime_from_cr(runtime))
+
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e
+        except Exception as e:
+            raise RuntimeError(f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e
+
+        return result
+
+    def get_runtime(self, name: str) -> types.Runtime:
+        """Get the Runtime object"""
+
+        try:
+            thread = self.custom_api.get_cluster_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                constants.CLUSTER_TRAINING_RUNTIME_PLURAL,
+                name,
+                async_req=True,
+            )
+
+            runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(
+                thread.get(common_constants.DEFAULT_TIMEOUT)  # type: ignore
+            )
+
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: {name}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: {name}"
+            ) from e
+
+        return self.__get_runtime_from_cr(runtime)  # type: ignore
+
+    def get_runtime_packages(self, runtime: types.Runtime):
+        if runtime.trainer.trainer_type == types.TrainerType.BUILTIN_TRAINER:
+            raise ValueError("Cannot get Runtime packages for BuiltinTrainer")
+
+        # Create a deepcopy of the runtime to avoid modifying the original command.
+        runtime_copy = copy.deepcopy(runtime)
+
+        # Run mpirun only within the single process.
+        if runtime_copy.trainer.command[0] == "mpirun":
+            mpi_command = list(constants.MPI_COMMAND)
+            mpi_command[1:3] = ["-np", "1"]
+            runtime_copy.trainer.set_command(tuple(mpi_command))
+
+        def print_packages():
+            import shutil
+            import subprocess
+            import sys
+
+            # Print Python version.
+            print(f"Python: {sys.version}")
+
+            # Print Python packages.
+            if shutil.which("pip"):
+                pip_list = subprocess.run(["pip", "list"], capture_output=True, text=True)
+                print(pip_list.stdout)
+            else:
+                print("Unable to get installed packages: pip command not found")
+
+            # Print nvidia-smi if GPUs are available.
+            if shutil.which("nvidia-smi"):
+                print("Available GPUs on the single training node")
+                nvidia_smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
+                print(nvidia_smi.stdout)
+
+        # Create the TrainJob and wait until it completes.
+        # If Runtime trainer has GPU resources use them, otherwise run TrainJob with 1 CPU.
+        job_name = self.train(
+            runtime=runtime_copy,
+            trainer=types.CustomTrainer(
+                func=print_packages,
+                num_nodes=1,
+                resources_per_node=({"cpu": 1} if runtime_copy.trainer.device != "gpu" else None),
+            ),
+        )
+
+        self.wait_for_job_status(job_name)
+        print("\n".join(self.get_job_logs(name=job_name)))
+        self.delete_job(job_name)
+
+    def train(
+        self,
+        runtime: Optional[Union[str, types.Runtime]] = None,
+        initializer: Optional[types.Initializer] = None,
+        trainer: Optional[
+            Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
+        ] = None,
+        options: Optional[list] = None,
+    ) -> str:
+        # Process options to extract configuration
+        job_spec = {}
+        labels = None
+        annotations = None
+        name = None
+        spec_labels = None
+        spec_annotations = None
+        trainer_overrides = {}
+        pod_template_overrides = None
+
+        if options:
+            for option in options:
+                option(job_spec, trainer, self)
+
+            metadata_section = job_spec.get("metadata", {})
+            labels = metadata_section.get("labels")
+            annotations = metadata_section.get("annotations")
+            name = metadata_section.get("name")
+
+            # Extract spec-level labels/annotations and other spec configurations
+            spec_section = job_spec.get("spec", {})
+            spec_labels = spec_section.get("labels")
+            spec_annotations = spec_section.get("annotations")
+            trainer_overrides = spec_section.get("trainer", {})
+            pod_template_overrides = spec_section.get("podTemplateOverrides")
+
+        # Generate unique name for the TrainJob if not provided
+        train_job_name = name or (
+            random.choice(string.ascii_lowercase)
+            + uuid.uuid4().hex[: constants.JOB_NAME_UUID_LENGTH]
+        )
+
+        # Build the TrainJob spec using the common _get_trainjob_spec method
+        trainjob_spec = self._get_trainjob_spec(
+            runtime=runtime,
+            initializer=initializer,
+            trainer=trainer,
+            trainer_overrides=trainer_overrides,
+            spec_labels=spec_labels,
+            spec_annotations=spec_annotations,
+            pod_template_overrides=pod_template_overrides,
+        )
+
+        # Build the TrainJob.
+        train_job = models.TrainerV1alpha1TrainJob(
+            apiVersion=constants.API_VERSION,
+            kind=constants.TRAINJOB_KIND,
+            metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
+                name=train_job_name, labels=labels, annotations=annotations
+            ),
+            spec=trainjob_spec,
+        )
+
+        # Create the TrainJob.
+        try:
+            self.custom_api.create_namespaced_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                self.namespace,
+                constants.TRAINJOB_PLURAL,
+                train_job.to_dict(),
+            )
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to create {constants.TRAINJOB_KIND}: {self.namespace}/{train_job_name}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to create {constants.TRAINJOB_KIND}: {self.namespace}/{train_job_name}"
+            ) from e
+
+        logger.debug(
+            f"{constants.TRAINJOB_KIND} {self.namespace}/{train_job_name} has been created"
+        )
+
+        return train_job_name
+
+    def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
+        result = []
+        try:
+            thread = self.custom_api.list_namespaced_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                self.namespace,
+                constants.TRAINJOB_PLURAL,
+                async_req=True,
+            )
+
+            trainjob_list = models.TrainerV1alpha1TrainJobList.from_dict(
+                thread.get(common_constants.DEFAULT_TIMEOUT)
+            )
+
+            if not trainjob_list:
+                return result
+
+            for trainjob in trainjob_list.items:
+                # If runtime object is set, we check the TrainJob's runtime reference.
+                if (
+                    runtime is not None
+                    and trainjob.spec
+                    and trainjob.spec.runtime_ref
+                    and trainjob.spec.runtime_ref.name != runtime.name
+                ):
+                    continue
+
+                result.append(self.__get_trainjob_from_cr(trainjob))
+
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to list {constants.TRAINJOB_KIND}s in namespace: {self.namespace}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to list {constants.TRAINJOB_KIND}s in namespace: {self.namespace}"
+            ) from e
+
+        return result
+
+    def get_job(self, name: str) -> types.TrainJob:
+        """Get the TrainJob object"""
+
+        try:
+            thread = self.custom_api.get_namespaced_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                self.namespace,
+                constants.TRAINJOB_PLURAL,
+                name,
+                async_req=True,
+            )
+
+            trainjob = models.TrainerV1alpha1TrainJob.from_dict(
+                thread.get(common_constants.DEFAULT_TIMEOUT)  # type: ignore
+            )
+
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to get {constants.TRAINJOB_KIND}: {self.namespace}/{name}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to get {constants.TRAINJOB_KIND}: {self.namespace}/{name}"
+            ) from e
+
+        return self.__get_trainjob_from_cr(trainjob)  # type: ignore
+
+    def get_job_logs(
+        self,
+        name: str,
+        follow: bool = False,
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
+        """Get the TrainJob logs"""
+        # Get the TrainJob Pod name.
+        pod_name = None
+        for c in self.get_job(name).steps:
+            if c.status != constants.POD_PENDING and c.name == step:
+                pod_name = c.pod_name
+                break
+        if pod_name is None:
+            return
+
+        # Remove the number for the node step.
+        container_name = re.sub(r"-\d+$", "", step)
+        yield from self._read_pod_logs(
+            pod_name=pod_name, container_name=container_name, follow=follow
+        )
+
+    def wait_for_job_status(
+        self,
+        name: str,
+        status: set[str] = {constants.TRAINJOB_COMPLETE},
+        timeout: int = 600,
+        polling_interval: int = 2,
+        callbacks: Optional[list[Callable[[types.TrainJob], None]]] = None,
+    ) -> types.TrainJob:
+        job_statuses = {
+            constants.TRAINJOB_CREATED,
+            constants.TRAINJOB_RUNNING,
+            constants.TRAINJOB_COMPLETE,
+            constants.TRAINJOB_FAILED,
+        }
+        if not status.issubset(job_statuses):
+            raise ValueError(f"Expected status {status} must be a subset of {job_statuses}")
+
+        if polling_interval > timeout:
+            raise ValueError(
+                f"Polling interval {polling_interval} must be less than timeout: {timeout}"
+            )
+
+        for _ in range(round(timeout / polling_interval)):
+            # Check the status after event is generated for the TrainJob's Pods.
+            trainjob = self.get_job(name)
+            logger.debug(f"TrainJob {name}, status {trainjob.status}")
+
+            # Invoke callbacks if provided
+            if callbacks:
+                for callback in callbacks:
+                    callback(trainjob)
+
+            # Raise an error if TrainJob is Failed and it is not the expected status.
+            if (
+                constants.TRAINJOB_FAILED not in status
+                and trainjob.status == constants.TRAINJOB_FAILED
+            ):
+                raise RuntimeError(f"TrainJob {name} is Failed")
+
+            # Return the TrainJob if it reaches the expected status.
+            if trainjob.status in status:
+                return trainjob
+
+            time.sleep(polling_interval)
+
+        raise TimeoutError(f"Timeout waiting for TrainJob {name} to reach status: {status} status")
+
+    def delete_job(self, name: str):
+        try:
+            self.custom_api.delete_namespaced_custom_object(
+                constants.GROUP,
+                constants.VERSION,
+                self.namespace,
+                constants.TRAINJOB_PLURAL,
+                name=name,
+            )
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to delete {constants.TRAINJOB_KIND}: {self.namespace}/{name}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to delete {constants.TRAINJOB_KIND}: {self.namespace}/{name}"
+            ) from e
+
+        logger.debug(f"{constants.TRAINJOB_KIND} {self.namespace}/{name} has been deleted")
+
+    def get_job_events(self, name: str) -> list[types.Event]:
+        # Get all pod names related to this TrainJob
+        trainjob = self.get_job(name)
+
+        # Create set of all TrainJob-related resource names
+        trainjob_resources = {name}
+        for step in trainjob.steps:
+            trainjob_resources.add(step.name)
+            if step.pod_name:
+                trainjob_resources.add(step.pod_name)
+
+        events = []
+        try:
+            # Retrieve events from the namespace
+            event_response: models.IoK8sApiCoreV1EventList = self.core_api.list_namespaced_event(
+                namespace=self.namespace,
+                async_req=True,
+            ).get(common_constants.DEFAULT_TIMEOUT)
+
+            # Filter events related to this TrainJob or its pods
+            for event in event_response.items:
+                if not (event.metadata and event.involved_object and event.first_timestamp):
+                    continue
+
+                involved_object = event.involved_object
+
+                # Check if event is related to TrainJob resources
+                if (
+                    involved_object.kind in {constants.TRAINJOB_KIND, "JobSet", "Job", "Pod"}
+                    and involved_object.name in trainjob_resources
+                ):
+                    events.append(
+                        types.Event(
+                            involved_object_kind=involved_object.kind,
+                            involved_object_name=involved_object.name,
+                            message=event.message or "",
+                            reason=event.reason or "",
+                            event_time=event.first_timestamp,
+                        )
+                    )
+
+            # Sort events by first occurrence time
+            events.sort(key=lambda e: e.event_time)
+
+            return events
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout getting {constants.TRAINJOB_KIND} events: {self.namespace}/{name}"
+            ) from e
+
+    def __get_runtime_from_cr(
+        self,
+        runtime_cr: models.TrainerV1alpha1ClusterTrainingRuntime,
+    ) -> types.Runtime:
+        if not (
+            runtime_cr.metadata
+            and runtime_cr.metadata.name
+            and runtime_cr.spec
+            and runtime_cr.spec.ml_policy
+            and runtime_cr.spec.template.spec
+            and runtime_cr.spec.template.spec.replicated_jobs
+        ):
+            raise Exception(f"ClusterTrainingRuntime CR is invalid: {runtime_cr}")
+
+        if not (
+            runtime_cr.metadata.labels
+            and constants.RUNTIME_FRAMEWORK_LABEL in runtime_cr.metadata.labels
+        ):
+            raise Exception(
+                f"Runtime {runtime_cr.metadata.name} must have "
+                f"{constants.RUNTIME_FRAMEWORK_LABEL} label"
+            )
+
+        return types.Runtime(
+            name=runtime_cr.metadata.name,
+            trainer=utils.get_runtime_trainer(
+                runtime_cr.metadata.labels[constants.RUNTIME_FRAMEWORK_LABEL],
+                runtime_cr.spec.template.spec.replicated_jobs,
+                runtime_cr.spec.ml_policy,
+            ),
+        )
+
+    def _read_pod_logs(self, pod_name: str, container_name: str, follow: bool) -> Iterator[str]:
+        """Read logs from a pod container."""
+        try:
+            if follow:
+                log_stream = watch.Watch().stream(
+                    self.core_api.read_namespaced_pod_log,
+                    name=pod_name,
+                    namespace=self.namespace,
+                    container=container_name,
+                    follow=True,
+                )
+
+                # Stream logs incrementally.
+                yield from log_stream  # type: ignore
+            else:
+                logs = self.core_api.read_namespaced_pod_log(
+                    name=pod_name,
+                    namespace=self.namespace,
+                    container=container_name,
+                )
+
+                yield from logs.splitlines()
+
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to read logs for the pod {self.namespace}/{pod_name}"
+            ) from e
+
+    def __get_trainjob_from_cr(
+        self,
+        trainjob_cr: models.TrainerV1alpha1TrainJob,
+    ) -> types.TrainJob:
+        if not (
+            trainjob_cr.metadata
+            and trainjob_cr.metadata.name
+            and trainjob_cr.metadata.namespace
+            and trainjob_cr.spec
+            and trainjob_cr.metadata.creation_timestamp
+        ):
+            raise Exception(f"TrainJob CR is invalid: {trainjob_cr}")
+
+        name = trainjob_cr.metadata.name
+        namespace = trainjob_cr.metadata.namespace
+
+        runtime = self.get_runtime(trainjob_cr.spec.runtime_ref.name)
+
+        # Construct the TrainJob from the CR.
+        trainjob = types.TrainJob(
+            name=name,
+            creation_timestamp=trainjob_cr.metadata.creation_timestamp,
+            runtime=runtime,
+            steps=[],
+            # Number of nodes is taken from TrainJob or TrainingRuntime
+            num_nodes=(
+                trainjob_cr.spec.trainer.num_nodes
+                if trainjob_cr.spec.trainer and trainjob_cr.spec.trainer.num_nodes
+                else runtime.trainer.num_nodes
+            ),
+            status=constants.TRAINJOB_CREATED,  # The default TrainJob status.
+        )
+
+        # Add the TrainJob components, e.g. trainer nodes and initializer.
+        try:
+            response = self.core_api.list_namespaced_pod(
+                namespace,
+                label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name),
+                async_req=True,
+            ).get(common_constants.DEFAULT_TIMEOUT)
+
+            # Convert Pod to the correct format.
+            # This is required to convert Pod's container resources into API object from str
+            pod_list = models.IoK8sApiCoreV1PodList.from_dict(response.to_dict())
+            if not pod_list:
+                return trainjob
+
+            for pod in pod_list.items:
+                # Pod must have labels to detect the TrainJob step.
+                # Every Pod always has a single TrainJob step.
+                if not (pod.metadata and pod.metadata.name and pod.metadata.labels and pod.spec):
+                    raise Exception(f"TrainJob Pod is invalid: {pod}")
+
+                # Get the Initializer step.
+                if pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL] in {
+                    constants.DATASET_INITIALIZER,
+                    constants.MODEL_INITIALIZER,
+                }:
+                    trainjob.steps.append(
+                        utils.get_trainjob_initializer_step(
+                            pod.metadata.name,
+                            pod.spec,
+                            pod.status,
+                        )
+                    )
+                # Get the Node step.
+                elif pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL] in {
+                    constants.LAUNCHER,
+                    constants.NODE,
+                }:
+                    trainjob.steps.append(
+                        utils.get_trainjob_node_step(
+                            pod.metadata.name,
+                            pod.spec,
+                            pod.status,
+                            trainjob.runtime,
+                            pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL],
+                            int(pod.metadata.labels[constants.JOB_INDEX_LABEL]),
+                        )
+                    )
+        except multiprocessing.TimeoutError as e:
+            raise TimeoutError(
+                f"Timeout to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}"
+            ) from e
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}"
+            ) from e
+
+        # Update the TrainJob status from its conditions.
+        if trainjob_cr.status and trainjob_cr.status.conditions:
+            for c in trainjob_cr.status.conditions:
+                if (
+                    c.type == constants.TRAINJOB_COMPLETE
+                    and c.status == "True"
+                    or c.type == constants.TRAINJOB_FAILED
+                    and c.status == "True"
+                ):
+                    trainjob.status = c.type
+        else:
+            # The TrainJob running status is defined when all training node (e.g. Pods) are
+            # running or succeeded.
+            num_running_nodes = sum(
+                1
+                for step in trainjob.steps
+                if step.name.startswith(constants.NODE)
+                and (
+                    step.status == constants.TRAINJOB_RUNNING
+                    or step.status == constants.POD_SUCCEEDED
+                )
+            )
+
+            if trainjob.num_nodes == num_running_nodes:
+                trainjob.status = constants.TRAINJOB_RUNNING
+
+        return trainjob
+
+    def _get_trainjob_spec(
+        self,
+        runtime: Optional[Union[str, types.Runtime]] = None,
+        initializer: Optional[types.Initializer] = None,
+        trainer: Optional[
+            Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
+        ] = None,
+        trainer_overrides: Optional[dict[str, Any]] = None,
+        spec_labels: Optional[dict[str, str]] = None,
+        spec_annotations: Optional[dict[str, str]] = None,
+        pod_template_overrides: Optional[models.IoK8sApiCoreV1PodTemplateSpec] = None,
+    ) -> models.TrainerV1alpha1TrainJobSpec:
+        """Get TrainJob spec from the given parameters"""
+
+        if runtime is None:
+            runtime = self.get_runtime(constants.DEFAULT_TRAINING_RUNTIME)
+        elif isinstance(runtime, str):
+            runtime = self.get_runtime(runtime)
+
+        # Build the Trainer.
+        trainer_cr = models.TrainerV1alpha1Trainer()
+
+        if trainer:
+            # If users choose to use a custom training script.
+            if isinstance(trainer, (types.CustomTrainer, types.CustomTrainerContainer)):
+                if runtime.trainer.trainer_type != types.TrainerType.CUSTOM_TRAINER:
+                    raise ValueError(f"CustomTrainer can't be used with {runtime} runtime")
+                trainer_cr = utils.get_trainer_cr_from_custom_trainer(runtime, trainer)
+
+            # If users choose to use a builtin trainer for post-training.
+            elif isinstance(trainer, types.BuiltinTrainer):
+                if runtime.trainer.trainer_type != types.TrainerType.BUILTIN_TRAINER:
+                    raise ValueError(f"BuiltinTrainer can't be used with {runtime} runtime")
+                trainer_cr = utils.get_trainer_cr_from_builtin_trainer(
+                    runtime, trainer, initializer
+                )
+
+            else:
+                raise ValueError(
+                    f"The trainer type {type(trainer)} is not supported. "
+                    "Please use CustomTrainer, CustomTrainerContainer, or BuiltinTrainer."
+                )
+
+        # Apply trainer overrides if trainer was not provided but overrides exist
+        if trainer_overrides:
+            if "command" in trainer_overrides:
+                trainer_cr.command = trainer_overrides["command"]
+            if "args" in trainer_overrides:
+                trainer_cr.args = trainer_overrides["args"]
+
+        trainjob_spec = models.TrainerV1alpha1TrainJobSpec(
+            runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
+            trainer=trainer_cr if trainer_cr != models.TrainerV1alpha1Trainer() else None,
+            labels=spec_labels,
+            annotations=spec_annotations,
+            pod_template_overrides=pod_template_overrides,
+        )
+
+        # Add initializer if users define it.
+        if initializer and (initializer.dataset or initializer.model):
+            trainjob_spec.initializer = models.TrainerV1alpha1Initializer(
+                dataset=utils.get_dataset_initializer(initializer.dataset)
+                if initializer.dataset
+                else None,
+                model=utils.get_model_initializer(initializer.model) if initializer.model else None,
+            )
+
+        return trainjob_spec
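
To make the API surface in the diff above easier to follow, here is a minimal usage sketch of the KubernetesBackend it adds. This is not taken from the package itself: it assumes a reachable cluster whose first listed ClusterTrainingRuntime uses a custom trainer, assumes KubernetesBackendConfig accepts its fields (namespace, context, config_file, client_configuration) as constructor arguments, and mirrors the module paths listed above even though the package may re-export these names elsewhere. The with_name helper is hypothetical and only illustrates the option-callable protocol that train() expects (each option is called as option(job_spec, trainer, backend) and may populate job_spec["metadata"] and job_spec["spec"]); the shipped options live under viettelcloud/aiplatform/trainer/options/.

    from viettelcloud.aiplatform.common.types import KubernetesBackendConfig
    from viettelcloud.aiplatform.trainer.backends.kubernetes.backend import KubernetesBackend
    from viettelcloud.aiplatform.trainer.types import types


    def train_fn():
        # Hypothetical training function; it runs on every trainer node Pod.
        print("hello from the training node")


    def with_name(name):
        # Hypothetical option: train() invokes each option as option(job_spec, trainer, backend)
        # and then reads job_spec["metadata"]["name"] as the TrainJob name.
        def apply(job_spec, trainer, backend):
            job_spec.setdefault("metadata", {})["name"] = name
        return apply


    # "demo" is a placeholder namespace; kube-config is loaded from the local environment.
    backend = KubernetesBackend(KubernetesBackendConfig(namespace="demo"))

    # Use any ClusterTrainingRuntime registered in the cluster.
    runtime = backend.list_runtimes()[0]

    job_name = backend.train(
        runtime=runtime,
        trainer=types.CustomTrainer(func=train_fn, num_nodes=1, resources_per_node={"cpu": 1}),
        options=[with_name("demo-trainjob")],
    )

    # Block until the TrainJob completes (RuntimeError if it fails, TimeoutError on timeout),
    # then print the node-0 logs and clean up.
    backend.wait_for_job_status(job_name)
    print("\n".join(backend.get_job_logs(name=job_name)))
    backend.delete_job(job_name)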