torchx-nightly 2025.7.9__py3-none-any.whl → 2025.11.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +92 -81
- torchx/runner/config.py +11 -9
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +20 -15
- torchx/schedulers/aws_batch_scheduler.py +45 -2
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/local_scheduler.py +1 -0
- torchx/schedulers/slurm_scheduler.py +160 -26
- torchx/specs/__init__.py +23 -6
- torchx/specs/api.py +279 -33
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0

torchx/schedulers/kubernetes_scheduler.py

@@ -27,10 +27,81 @@ Install Volcano:
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
+the ``kubernetes`` metadata on your role. The value can be:
+
+- A dict with the overlay structure
+- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
+
+Merge semantics:
+- **dict**: recursive merge (upsert)
+- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
+- **primitives**: replace
+
+.. code:: python
+
+    from torchx.specs import Role
+
+    # Dict overlay - lists append, tuples replace
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": {
+                "spec": {
+                    "nodeSelector": {"gpu": "true"},
+                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],  # appends
+                    "volumes": ({"name": "my-volume", "emptyDir": {}},)  # replaces
+                }
+            }
+        }
+    )
+
+    # File URI overlay
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": "file:///path/to/pod_overlay.yaml"
+        }
+    )
+
+CLI usage with builtin components:
+
+.. code:: bash
+
+    $ torchx run --scheduler kubernetes dist.ddp \\
+        --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
+        --script train.py
+
+Example ``pod_overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+      - key: nvidia.com/gpu
+        operator: Exists
+        effect: NoSchedule
+      volumes: !!python/tuple
+      - name: my-volume
+        emptyDir: {}
+
+The overlay is deep-merged with the generated pod, preserving existing fields
+and adding or overriding specified ones.
 """
 
 import json
 import logging
+import re
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
@@ -45,6 +116,7 @@ from typing import (
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )
 
 import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024
 
+
+def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+    """Apply overlay dict to V1Pod object, merging nested fields.
+
+    Merge semantics:
+    - dict: upsert (recursive merge)
+    - list: append by default, replace if tuple
+    - primitives: replace
+    """
+    from kubernetes import client
+
+    api = client.ApiClient()
+    pod_dict = api.sanitize_for_serialization(pod)
+
+    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+        for key, value in overlay.items():
+            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                deep_merge(base[key], value)
+            elif isinstance(value, tuple):
+                base[key] = list(value)
+            elif (
+                isinstance(value, list) and key in base and isinstance(base[key], list)
+            ):
+                base[key].extend(value)
+            else:
+                base[key] = value
+
+    deep_merge(pod_dict, overlay)
+
+    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+    pod.spec = merged_pod.spec
+    pod.metadata = merged_pod.metadata
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
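
A standalone sketch of the merge rules implemented by ``deep_merge`` above, runnable on plain dicts; the ``base``/``overlay`` values are made up for illustration:

    from typing import Any, Dict

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        # same rules as the hunk above: dict -> recursive upsert, tuple -> replace,
        # list -> append, anything else -> replace
        for key, value in overlay.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                deep_merge(base[key], value)
            elif isinstance(value, tuple):
                base[key] = list(value)
            elif isinstance(value, list) and isinstance(base.get(key), list):
                base[key].extend(value)
            else:
                base[key] = value

    base = {"spec": {"tolerations": [{"key": "a"}], "volumes": [{"name": "v0"}]}}
    overlay = {
        "spec": {
            "nodeSelector": {"gpu": "true"},             # new key -> inserted
            "tolerations": [{"key": "nvidia.com/gpu"}],  # list -> appended
            "volumes": ({"name": "my-volume"},),         # tuple -> replaces the list
        }
    }
    deep_merge(base, overlay)
    # base["spec"]["tolerations"] == [{"key": "a"}, {"key": "nvidia.com/gpu"}]
    # base["spec"]["volumes"] == [{"name": "my-volume"}]
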
@@ -369,7 +475,7 @@ def app_to_resource(
     queue: str,
     service_account: Optional[str],
     priority_class: Optional[str] = None,
-) -> Dict[str,
+) -> Dict[str, Any]:
     """
     app_to_resource creates a volcano job kubernetes resource definition from
     the provided AppDef. The resource definition can be used to launch the
@@ -399,8 +505,20 @@ def app_to_resource(
             replica_role = values.apply(role)
             if role_idx == 0 and replica_id == 0:
                 replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+                replica_role.env["TORCHX_IMAGE"] = replica_role.image
 
             pod = role_to_pod(name, replica_role, service_account)
+            if k8s_metadata := role.metadata.get("kubernetes"):
+                if isinstance(k8s_metadata, str):
+                    import fsspec
+
+                    with fsspec.open(k8s_metadata, "r") as f:
+                        k8s_metadata = yaml.unsafe_load(f)
+                elif not isinstance(k8s_metadata, dict):
+                    raise ValueError(
+                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
+                    )
+                _apply_pod_overlay(pod, k8s_metadata)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
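
A small sketch of why the ``!!python/tuple`` tag in a YAML overlay triggers the "replace" rule: ``yaml.unsafe_load`` (the loader used in this hunk) builds a Python tuple for the tagged node, while untagged sequences load as lists. Illustrative only:

    import yaml

    doc = """
    spec:
      # plain sequence -> list -> appended to the generated pod's list
      tolerations:
      - key: nvidia.com/gpu
        operator: Exists
      # tagged sequence -> tuple -> replaces the generated pod's list
      volumes: !!python/tuple
      - name: my-volume
        emptyDir: {}
    """

    overlay = yaml.unsafe_load(doc)
    print(type(overlay["spec"]["tolerations"]))  # <class 'list'>
    print(type(overlay["spec"]["volumes"]))      # <class 'tuple'>
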
@@ -443,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
     if priority_class is not None:
         job_spec["priorityClassName"] = priority_class
 
-    resource: Dict[str,
+    resource: Dict[str, Any] = {
         "apiVersion": "batch.volcano.sh/v1alpha1",
         "kind": "Job",
         "metadata": {"name": f"{unique_app_id}"},
@@ -455,7 +573,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
 @dataclass
 class KubernetesJob:
     images_to_push: Dict[str, Tuple[str, str]]
-    resource: Dict[str,
+    resource: Dict[str, Any]
 
     def __str__(self) -> str:
         return yaml.dump(sanitize_for_serialization(self.resource))
@@ -470,6 +588,7 @@ class KubernetesOpts(TypedDict, total=False):
     image_repo: Optional[str]
     service_account: Optional[str]
     priority_class: Optional[str]
+    validate_spec: Optional[bool]
 
 
 class KubernetesScheduler(
@@ -485,7 +604,7 @@ class KubernetesScheduler(
     For installation instructions see: https://github.com/volcano-sh/volcano
 
     This has been confirmed to work with Volcano v1.3.0 and Kubernetes versions
-    v1.18-1.21. See https://github.com/pytorch/torchx/issues/120 which is
+    v1.18-1.21. See https://github.com/meta-pytorch/torchx/issues/120 which is
     tracking Volcano support for Kubernetes v1.22.
 
     .. note::
@@ -635,7 +754,7 @@ class KubernetesScheduler(
             else:
                 raise
 
-        return f
+        return f"{namespace}:{resp['metadata']['name']}"
 
     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts
@@ -658,6 +777,36 @@ class KubernetesScheduler(
         ), "priority_class must be a str"
 
         resource = app_to_resource(app, queue, service_account, priority_class)
+
+        if cfg.get("validate_spec"):
+            try:
+                self._custom_objects_api().create_namespaced_custom_object(
+                    group="batch.volcano.sh",
+                    version="v1alpha1",
+                    namespace=cfg.get("namespace") or "default",
+                    plural="jobs",
+                    body=resource,
+                    dry_run="All",
+                )
+            except Exception as e:
+                from kubernetes.client.rest import ApiException
+
+                if isinstance(e, ApiException):
+                    raise ValueError(f"Invalid job spec: {e.reason}") from e
+                raise
+
+        job_name = resource["metadata"]["name"]
+        for task in resource["spec"]["tasks"]:
+            task_name = task["name"]
+            replicas = task.get("replicas", 1)
+            max_index = replicas - 1
+            pod_name = f"{job_name}-{task_name}-{max_index}"
+            if len(pod_name) > 63:
+                raise ValueError(
+                    f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
+                    f"Shorten app.name or role names"
+                )
+
         req = KubernetesJob(
             resource=resource,
             images_to_push=images_to_push,
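
Before submitting, the 63-character limit enforced above can be sanity-checked up front. ``longest_pod_name`` below is a hypothetical helper that mirrors the ``{job_name}-{task_name}-{max_index}`` pattern from this hunk; the job and task names are invented:

    from typing import Dict

    def longest_pod_name(job_name: str, task_replicas: Dict[str, int]) -> str:
        # mirrors the check above: pod names are "<job_name>-<task_name>-<max_index>"
        names = [
            f"{job_name}-{task}-{count - 1}" for task, count in task_replicas.items()
        ]
        return max(names, key=len)

    name = longest_pod_name("trainer-pa2rv7qcwwl5q", {"trainer-0": 1, "reader-0": 1})
    assert len(name) <= 63, f"{name} would be rejected before submission"
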
@@ -702,19 +851,32 @@ class KubernetesScheduler(
             type_=str,
             help="The name of the PriorityClass to set on the job specs",
         )
+        opts.add(
+            "validate_spec",
+            type_=bool,
+            help="Validate job spec using Kubernetes API dry-run before submission",
+            default=True,
+        )
         return opts
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        from kubernetes.client.rest import ApiException
+
         namespace, name = app_id.split(":")
         roles = {}
         roles_statuses = {}
-
-
-
-
-
-
-
+        try:
+            resp = self._custom_objects_api().get_namespaced_custom_object_status(
+                group="batch.volcano.sh",
+                version="v1alpha1",
+                namespace=namespace,
+                plural="jobs",
+                name=name,
+            )
+        except ApiException as e:
+            if e.status == 404:
+                return None
+            raise
         status = resp.get("status")
         if status:
             state_str = status["state"]["phase"]
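
With ``validate_spec`` defaulting to ``True``, submissions go through a server-side dry-run first; a sketch of turning it off from Python, assuming the usual ``torchx.runner.get_runner()`` entry point and a kubernetes run config (names and values here are illustrative):

    from torchx.runner import get_runner
    from torchx.specs import AppDef, Role

    app = AppDef(
        name="trainer",
        roles=[Role(name="trainer", image="my-image:latest", entrypoint="train.py")],
    )

    runner = get_runner()
    # skip the dry-run validation added in this version, keep the other run options
    handle = runner.run(
        app,
        scheduler="kubernetes",
        cfg={"queue": "default", "validate_spec": False},
    )
    print(runner.status(handle))
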
@@ -823,13 +985,34 @@ def create_scheduler(
 def pod_labels(
     app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
 ) -> Dict[str, str]:
+
+    def clean(label_value: str) -> str:
+        # cleans the provided `label_value` to make it compliant
+        # to pod label specs as described in
+        # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
+        #
+        # Valid label value:
+        # must be 63 characters or less (can be empty),
+        # unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
+        # could contain dashes (-), underscores (_), dots (.), and alphanumerics between.
+
+        # Replace invalid characters (allow: alphanum, -, _, .) with "."
+        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
+        # Replace leading non-alphanumeric with "."
+        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
+        # Replace trailing non-alphanumeric with "."
+        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
+
+        # Trim to 63 characters
+        return label_value[:63]
+
     return {
-        LABEL_VERSION: torchx.__version__,
-        LABEL_APP_NAME: app.name,
+        LABEL_VERSION: clean(torchx.__version__),
+        LABEL_APP_NAME: clean(app.name),
         LABEL_ROLE_INDEX: str(role_idx),
-        LABEL_ROLE_NAME: role.name,
+        LABEL_ROLE_NAME: clean(role.name),
         LABEL_REPLICA_ID: str(replica_id),
-        LABEL_KUBE_APP_NAME: app.name,
+        LABEL_KUBE_APP_NAME: clean(app.name),
         LABEL_ORGANIZATION: "torchx.pytorch.org",
-        LABEL_UNIQUE_NAME: app_id,
+        LABEL_UNIQUE_NAME: clean(app_id),
     }
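
The ``clean()`` helper exists because values such as ``torchx.__version__`` or user-supplied app names may contain characters (spaces, ``+``, ``/``, etc.) that are invalid in Kubernetes label values. A standalone sketch of the same three substitutions, with made-up inputs:

    import re

    def clean(label_value: str) -> str:
        # same three steps as the clean() helper above
        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)  # drop illegal chars
        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)    # collapse leading run
        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)    # collapse trailing run
        return label_value[:63]                                      # enforce the 63-char cap

    print(clean("2025.11.12+cpu"))  # -> 2025.11.12.cpu
    print(clean("my trainer/v1"))   # -> my.trainer.v1
    print(clean("a" * 80))          # -> 63 "a" characters
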

torchx/schedulers/slurm_scheduler.py

@@ -18,6 +18,7 @@ import os.path
 import shlex
 import subprocess
 import tempfile
+import warnings
 from dataclasses import dataclass
 from datetime import datetime
 from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,64 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
 
 
+def get_appstate_from_job(job: dict[str, object]) -> AppState:
+    # Prior to slurm-23.11, job_state was a string and not a list
+    job_state = job.get("job_state", None)
+    if isinstance(job_state, list):
+        return appstate_from_slurm_state(job_state[0])
+    else:
+        return appstate_from_slurm_state(str(job_state))
+
+
+def version() -> Tuple[int, int]:
+    """
+    Uses ``sinfo --version`` to get the slurm version. If the command fails, it
+    assumes the version is ``slurm 24.05.8``.
+
+    Returns:
+    -------
+    Tuple[int, int] slurm version as a tuple of ints (major, minor).
+    """
+
+    cmd = ["sinfo", "--version"]
+    try:
+        out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
+    except (CalledProcessError, FileNotFoundError):
+        out = "slurm 24.05.8"
+        warnings.warn(
+            "Error running: `{sinfo_cmd}` to get SLURM version. Are you running outside the "
+            "cluster's login or head node? This typically happens when running in `--dryrun`"
+            " mode. Assuming version is `slurm 24.05.8`.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    # sinfo --version returns in the form "slurm 24.1.0"
+    _, version_literal = out.split(" ", maxsplit=2)
+    major, minor = [int(v) for v in version_literal.split(".")][:2]
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version() -> bool:
+    """
+    Determine whether to use gpus-per-node based on automatically detected slurm version.
+
+    Change Reference: https://fburl.com/sqwqzxn6
+    > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
+
+    Returns:
+        ``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
+    """
+
+    slurm_24_11_0 = (24, 11)
+    slurm_version = version()
+
+    return slurm_version[0] > slurm_24_11_0[0] or (  # Major version is greater
+        slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
+    )  # Major version is equal and minor version is greater or equal
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
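
A small sketch of the parsing and comparison done by ``version()`` and ``_should_use_gpus_per_node_from_version()`` above; the sample version strings are arbitrary:

    from typing import Tuple

    def parse_sinfo_version(out: str) -> Tuple[int, int]:
        # same parsing as version() above: "slurm 24.05.8" -> (24, 5)
        _, version_literal = out.split(" ", maxsplit=2)
        major, minor = [int(v) for v in version_literal.split(".")][:2]
        return (major, minor)

    for sample in ("slurm 24.05.8", "slurm 24.11.0", "slurm 25.02.1"):
        major_minor = parse_sinfo_version(sample)
        uses_gpus_per_node = major_minor >= (24, 11)  # equivalent to the version gate above
        print(sample, major_minor, uses_gpus_per_node)
    # slurm 24.05.8 (24, 5) False
    # slurm 24.11.0 (24, 11) True
    # slurm 25.02.1 (25, 2) True
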
@@ -81,6 +140,7 @@ SBATCH_GROUP_OPTIONS = {
     "partition",
     "time",
     "constraint",
+    "qos",
 }
 
 log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +166,7 @@ SlurmOpts = TypedDict(
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
     },
     total=False,
 )
@@ -126,7 +187,11 @@ class SlurmReplicaRequest:
 
     @classmethod
     def from_role(
-        cls,
+        cls,
+        name: str,
+        role: Role,
+        cfg: SlurmOpts,
+        nomem: bool,
     ) -> "SlurmReplicaRequest":
         """
         ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +214,12 @@ class SlurmReplicaRequest:
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-
+            # Use smart GPU allocation based on automatically detected Slurm version
+            if _should_use_gpus_per_node_from_version():
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+        sbatch_opts.setdefault("ntasks", "1")
 
         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +448,11 @@ class SlurmScheduler(
        iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
        """,
         )
+        opts.add(
+            "qos",
+            type_=str,
+            help="Quality of Service (QoS) to assign to the job.",
+        )
         return opts
 
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -504,6 +579,8 @@ class SlurmScheduler(
             return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: Handles multiple job ID formats due to SLURM version differences.
+        # Different clusters use heterogeneous (+) vs regular (.) job ID formats.
         try:
             output = subprocess.check_output(
                 ["sacct", "--parsable2", "-j", app_id],
@@ -528,15 +605,27 @@ class SlurmScheduler(
         msg = ""
         app_state = AppState.UNKNOWN
         for row in reader:
-
+            # Handle both "+" (heterogeneous) and "." (regular) job ID formats
+            job_id_full = row["JobID"]
+
+            # Split on both "+" and "." to handle different SLURM configurations
+            if "+" in job_id_full:
+                job_id, *parts = job_id_full.split("+")
+                is_subjob = len(parts) > 0 and "." in parts[0]
+            else:
+                job_id, *parts = job_id_full.split(".")
+                is_subjob = len(parts) > 0
+
             if job_id != app_id:
                 continue
-
-
+
+            if is_subjob:
+                # we only care about the main job not the child jobs (.batch, .0, etc.)
                 continue
 
-
-
+            msg = row["State"]
+            # Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
+            state = msg.split()[0].rstrip("+")
             app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
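
A sketch of how the JobID parsing above classifies sacct rows; the sample IDs are fabricated:

    def classify(job_id_full: str) -> tuple:
        # same split logic as the loop above
        if "+" in job_id_full:
            job_id, *parts = job_id_full.split("+")
            is_subjob = len(parts) > 0 and "." in parts[0]
        else:
            job_id, *parts = job_id_full.split(".")
            is_subjob = len(parts) > 0
        return job_id, is_subjob

    print(classify("1234"))          # ('1234', False) -> processed as the main job
    print(classify("1234.batch"))    # ('1234', True)  -> skipped child step
    print(classify("1234+0"))        # ('1234', False) -> heterogeneous component, processed as the main job
    print(classify("1234+0.batch"))  # ('1234', True)  -> skipped child step
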
@@ -563,6 +652,9 @@ class SlurmScheduler(
         )
 
     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: This method contains multiple compatibility checks for different SLURM versions
+        # due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).
+
         # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
         # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
         output = subprocess.check_output(
@@ -583,7 +675,7 @@ class SlurmScheduler(
 
             entrypoint = job["command"]
             image = job["current_working_directory"]
-            state =
+            state = get_appstate_from_job(job)
 
             job_resources = job["job_resources"]
 
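
``get_appstate_from_job`` (added earlier in this diff) absorbs the slurm-23.11 change where ``job_state`` became a list; a small illustration with fabricated squeue records:

    def job_state_value(job: dict) -> str:
        # same branch as get_appstate_from_job: newer slurm returns a list
        job_state = job.get("job_state", None)
        if isinstance(job_state, list):
            return job_state[0]
        return str(job_state)

    print(job_state_value({"job_state": "RUNNING"}))    # older squeue --json payloads
    print(job_state_value({"job_state": ["RUNNING"]}))  # slurm 23.11+ payloads
    # both print "RUNNING", which then maps through appstate_from_slurm_state()
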
@@ -604,7 +696,18 @@ class SlurmScheduler(
             if state == AppState.PENDING:
                 # NOTE: torchx launched jobs points to exactly one host
                 # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
-
+
+                # SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
+                if job_resources is not None:
+                    hostname = job_resources.get("scheduled_nodes", "")
+                    # If scheduled_nodes not found in job_resources, try nodes.list
+                    if not hostname and "nodes" in job_resources:
+                        nodes_info = job_resources.get("nodes", {})
+                        if isinstance(nodes_info, dict):
+                            hostname = nodes_info.get("list", "")
+                else:
+                    # For pending jobs where job_resources is None, check top-level fields
+                    hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")
 
                 role.num_replicas += 1
                 role_status.replicas.append(
@@ -620,24 +723,35 @@ class SlurmScheduler(
                 # where each replica is a "sub-job" so `allocated_nodes` will always be 1
                 # but we deal with jobs that have not been launched with torchx
                 # which can have multiple hosts per sub-job (count them as replicas)
-
+                nodes_data = job_resources.get("nodes", {})
+
+                # SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
+                if "allocation" in nodes_data and isinstance(
+                    nodes_data["allocation"], list
+                ):
+                    # SLURM 24.11+ format: nodes.allocation is a list
+                    for node_info in nodes_data["allocation"]:
+                        hostname = node_info["name"]
+                        cpu = int(node_info["cpus"]["used"])
+                        memMB = (
+                            int(node_info["memory"]["allocated"]) // 1024
+                        )  # Convert to MB
 
-
-
-
-
-
-
-
-
-
-                            role=role_name,
-                            state=state,
-                            hostname=hostname,
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
                         )
-
-
-
+                elif "allocated_nodes" in job_resources and isinstance(
+                    job_resources["allocated_nodes"], list
+                ):
+                    # Legacy format: allocated_nodes is a list
+                    for node_info in job_resources["allocated_nodes"]:
                         # NOTE: we expect resource specs for all the nodes to be the same
                         # NOTE: use allocated (not used/requested) memory since
                         # users may only specify --cpu, in which case slurm
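
A fabricated example of the SLURM 24.11+ ``job_resources`` shape that the ``nodes.allocation`` branch above reads; only the fields touched by this code are shown, and the legacy ``allocated_nodes`` branch keeps its previous field handling:

    # minimal 24.11+-style job_resources payload (fabricated values)
    job_resources = {
        "nodes": {
            "allocation": [
                {"name": "node-0", "cpus": {"used": 8}, "memory": {"allocated": 65536}},
                {"name": "node-1", "cpus": {"used": 8}, "memory": {"allocated": 65536}},
            ]
        }
    }

    nodes_data = job_resources.get("nodes", {})
    if "allocation" in nodes_data and isinstance(nodes_data["allocation"], list):
        for node_info in nodes_data["allocation"]:
            hostname = node_info["name"]
            cpu = int(node_info["cpus"]["used"])
            memMB = int(node_info["memory"]["allocated"]) // 1024  # same // 1024 as the hunk
            print(hostname, cpu, memMB)  # node-0 8 64 / node-1 8 64
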
@@ -660,6 +774,26 @@ class SlurmScheduler(
                                 hostname=hostname,
                             )
                         )
+                else:
+                    # Fallback: use hostname from nodes.list
+                    if isinstance(nodes_data, str):
+                        hostname = nodes_data
+                    else:
+                        hostname = (
+                            nodes_data.get("list", "")
+                            if isinstance(nodes_data, dict)
+                            else ""
+                        )
+
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
 
         return DescribeAppResponse(
             app_id=app_id,
@@ -756,7 +890,7 @@ class SlurmScheduler(
             out.append(
                 ListAppResponse(
                     app_id=str(job["job_id"]),
-                    state=
+                    state=get_appstate_from_job(job),
                     name=job["name"],
                 )
             )