torchx-nightly 2025.7.9__py3-none-any.whl → 2025.11.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  2. torchx/cli/cmd_list.py +1 -2
  3. torchx/cli/cmd_run.py +202 -28
  4. torchx/cli/cmd_tracker.py +1 -1
  5. torchx/components/__init__.py +1 -8
  6. torchx/components/dist.py +9 -3
  7. torchx/components/integration_tests/component_provider.py +2 -2
  8. torchx/components/utils.py +1 -1
  9. torchx/distributed/__init__.py +1 -1
  10. torchx/runner/api.py +92 -81
  11. torchx/runner/config.py +11 -9
  12. torchx/runner/events/__init__.py +20 -10
  13. torchx/runner/events/api.py +1 -1
  14. torchx/schedulers/__init__.py +7 -10
  15. torchx/schedulers/api.py +20 -15
  16. torchx/schedulers/aws_batch_scheduler.py +45 -2
  17. torchx/schedulers/docker_scheduler.py +3 -0
  18. torchx/schedulers/kubernetes_scheduler.py +200 -17
  19. torchx/schedulers/local_scheduler.py +1 -0
  20. torchx/schedulers/slurm_scheduler.py +160 -26
  21. torchx/specs/__init__.py +23 -6
  22. torchx/specs/api.py +279 -33
  23. torchx/specs/builders.py +109 -28
  24. torchx/specs/file_linter.py +117 -53
  25. torchx/specs/finder.py +25 -37
  26. torchx/specs/named_resources_aws.py +13 -2
  27. torchx/tracker/__init__.py +2 -2
  28. torchx/tracker/api.py +1 -1
  29. torchx/util/entrypoints.py +1 -6
  30. torchx/util/strings.py +1 -1
  31. torchx/util/types.py +12 -1
  32. torchx/version.py +2 -2
  33. torchx/workspace/api.py +102 -5
  34. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
  35. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
  36. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
  37. torchx/examples/pipelines/__init__.py +0 -0
  38. torchx/examples/pipelines/kfp/__init__.py +0 -0
  39. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
  40. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
  41. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
  42. torchx/pipelines/kfp/__init__.py +0 -30
  43. torchx/pipelines/kfp/adapter.py +0 -274
  44. torchx/pipelines/kfp/version.py +0 -19
  45. torchx/schedulers/gcp_batch_scheduler.py +0 -497
  46. torchx/schedulers/ray/ray_common.py +0 -22
  47. torchx/schedulers/ray/ray_driver.py +0 -307
  48. torchx/schedulers/ray_scheduler.py +0 -454
  49. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
  50. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
  51. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
@@ -27,10 +27,81 @@ Install Volcano:
27
27
  See the
28
28
  `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
29
29
  for more information.
30
+
31
+ Pod Overlay
32
+ ===========
33
+
34
+ You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
35
+ the ``kubernetes`` metadata on your role. The value can be:
36
+
37
+ - A dict with the overlay structure
38
+ - A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
39
+
40
+ Merge semantics:
41
+ - **dict**: recursive merge (upsert)
42
+ - **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
43
+ - **primitives**: replace
44
+
45
+ .. code:: python
46
+
47
+ from torchx.specs import Role
48
+
49
+ # Dict overlay - lists append, tuples replace
50
+ role = Role(
51
+ name="trainer",
52
+ image="my-image:latest",
53
+ entrypoint="train.py",
54
+ metadata={
55
+ "kubernetes": {
56
+ "spec": {
57
+ "nodeSelector": {"gpu": "true"},
58
+ "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends
59
+ "volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces
60
+ }
61
+ }
62
+ }
63
+ )
64
+
65
+ # File URI overlay
66
+ role = Role(
67
+ name="trainer",
68
+ image="my-image:latest",
69
+ entrypoint="train.py",
70
+ metadata={
71
+ "kubernetes": "file:///path/to/pod_overlay.yaml"
72
+ }
73
+ )
74
+
75
+ CLI usage with builtin components:
76
+
77
+ .. code:: bash
78
+
79
+ $ torchx run --scheduler kubernetes dist.ddp \\
80
+ --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
81
+ --script train.py
82
+
83
+ Example ``pod_overlay.yaml``:
84
+
85
+ .. code:: yaml
86
+
87
+ spec:
88
+ nodeSelector:
89
+ node.kubernetes.io/instance-type: p4d.24xlarge
90
+ tolerations:
91
+ - key: nvidia.com/gpu
92
+ operator: Exists
93
+ effect: NoSchedule
94
+ volumes: !!python/tuple
95
+ - name: my-volume
96
+ emptyDir: {}
97
+
98
+ The overlay is deep-merged with the generated pod, preserving existing fields
99
+ and adding or overriding specified ones.
30
100
  """
31
101
 
32
102
  import json
33
103
  import logging
104
+ import re
34
105
  import warnings
35
106
  from dataclasses import dataclass
36
107
  from datetime import datetime
@@ -45,6 +116,7 @@ from typing import (
45
116
  Tuple,
46
117
  TYPE_CHECKING,
47
118
  TypedDict,
119
+ Union,
48
120
  )
49
121
 
50
122
  import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
97
169
  RESERVED_MILLICPU = 100
98
170
  RESERVED_MEMMB = 1024
99
171
 
172
+
173
+ def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
174
+ """Apply overlay dict to V1Pod object, merging nested fields.
175
+
176
+ Merge semantics:
177
+ - dict: upsert (recursive merge)
178
+ - list: append by default, replace if tuple
179
+ - primitives: replace
180
+ """
181
+ from kubernetes import client
182
+
183
+ api = client.ApiClient()
184
+ pod_dict = api.sanitize_for_serialization(pod)
185
+
186
+ def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
187
+ for key, value in overlay.items():
188
+ if isinstance(value, dict) and key in base and isinstance(base[key], dict):
189
+ deep_merge(base[key], value)
190
+ elif isinstance(value, tuple):
191
+ base[key] = list(value)
192
+ elif (
193
+ isinstance(value, list) and key in base and isinstance(base[key], list)
194
+ ):
195
+ base[key].extend(value)
196
+ else:
197
+ base[key] = value
198
+
199
+ deep_merge(pod_dict, overlay)
200
+
201
+ merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
202
+ pod.spec = merged_pod.spec
203
+ pod.metadata = merged_pod.metadata
204
+
205
+
100
206
  RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
101
207
  RetryPolicy.REPLICA: [],
102
208
  RetryPolicy.APPLICATION: [
@@ -369,7 +475,7 @@ def app_to_resource(
369
475
  queue: str,
370
476
  service_account: Optional[str],
371
477
  priority_class: Optional[str] = None,
372
- ) -> Dict[str, object]:
478
+ ) -> Dict[str, Any]:
373
479
  """
374
480
  app_to_resource creates a volcano job kubernetes resource definition from
375
481
  the provided AppDef. The resource definition can be used to launch the
@@ -399,8 +505,20 @@ def app_to_resource(
399
505
  replica_role = values.apply(role)
400
506
  if role_idx == 0 and replica_id == 0:
401
507
  replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
508
+ replica_role.env["TORCHX_IMAGE"] = replica_role.image
402
509
 
403
510
  pod = role_to_pod(name, replica_role, service_account)
511
+ if k8s_metadata := role.metadata.get("kubernetes"):
512
+ if isinstance(k8s_metadata, str):
513
+ import fsspec
514
+
515
+ with fsspec.open(k8s_metadata, "r") as f:
516
+ k8s_metadata = yaml.unsafe_load(f)
517
+ elif not isinstance(k8s_metadata, dict):
518
+ raise ValueError(
519
+ f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
520
+ )
521
+ _apply_pod_overlay(pod, k8s_metadata)
404
522
  pod.metadata.labels.update(
405
523
  pod_labels(
406
524
  app=app,
@@ -443,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
443
561
  if priority_class is not None:
444
562
  job_spec["priorityClassName"] = priority_class
445
563
 
446
- resource: Dict[str, object] = {
564
+ resource: Dict[str, Any] = {
447
565
  "apiVersion": "batch.volcano.sh/v1alpha1",
448
566
  "kind": "Job",
449
567
  "metadata": {"name": f"{unique_app_id}"},
@@ -455,7 +573,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
455
573
  @dataclass
456
574
  class KubernetesJob:
457
575
  images_to_push: Dict[str, Tuple[str, str]]
458
- resource: Dict[str, object]
576
+ resource: Dict[str, Any]
459
577
 
460
578
  def __str__(self) -> str:
461
579
  return yaml.dump(sanitize_for_serialization(self.resource))
@@ -470,6 +588,7 @@ class KubernetesOpts(TypedDict, total=False):
470
588
  image_repo: Optional[str]
471
589
  service_account: Optional[str]
472
590
  priority_class: Optional[str]
591
+ validate_spec: Optional[bool]
473
592
 
474
593
 
475
594
  class KubernetesScheduler(
@@ -485,7 +604,7 @@ class KubernetesScheduler(
485
604
  For installation instructions see: https://github.com/volcano-sh/volcano
486
605
 
487
606
  This has been confirmed to work with Volcano v1.3.0 and Kubernetes versions
488
- v1.18-1.21. See https://github.com/pytorch/torchx/issues/120 which is
607
+ v1.18-1.21. See https://github.com/meta-pytorch/torchx/issues/120 which is
489
608
  tracking Volcano support for Kubernetes v1.22.
490
609
 
491
610
  .. note::
@@ -635,7 +754,7 @@ class KubernetesScheduler(
635
754
  else:
636
755
  raise
637
756
 
638
- return f'{namespace}:{resp["metadata"]["name"]}'
757
+ return f"{namespace}:{resp['metadata']['name']}"
639
758
 
640
759
  def _submit_dryrun(
641
760
  self, app: AppDef, cfg: KubernetesOpts
@@ -658,6 +777,36 @@ class KubernetesScheduler(
658
777
  ), "priority_class must be a str"
659
778
 
660
779
  resource = app_to_resource(app, queue, service_account, priority_class)
780
+
781
+ if cfg.get("validate_spec"):
782
+ try:
783
+ self._custom_objects_api().create_namespaced_custom_object(
784
+ group="batch.volcano.sh",
785
+ version="v1alpha1",
786
+ namespace=cfg.get("namespace") or "default",
787
+ plural="jobs",
788
+ body=resource,
789
+ dry_run="All",
790
+ )
791
+ except Exception as e:
792
+ from kubernetes.client.rest import ApiException
793
+
794
+ if isinstance(e, ApiException):
795
+ raise ValueError(f"Invalid job spec: {e.reason}") from e
796
+ raise
797
+
798
+ job_name = resource["metadata"]["name"]
799
+ for task in resource["spec"]["tasks"]:
800
+ task_name = task["name"]
801
+ replicas = task.get("replicas", 1)
802
+ max_index = replicas - 1
803
+ pod_name = f"{job_name}-{task_name}-{max_index}"
804
+ if len(pod_name) > 63:
805
+ raise ValueError(
806
+ f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
807
+ f"Shorten app.name or role names"
808
+ )
809
+
661
810
  req = KubernetesJob(
662
811
  resource=resource,
663
812
  images_to_push=images_to_push,
@@ -702,19 +851,32 @@ class KubernetesScheduler(
702
851
  type_=str,
703
852
  help="The name of the PriorityClass to set on the job specs",
704
853
  )
854
+ opts.add(
855
+ "validate_spec",
856
+ type_=bool,
857
+ help="Validate job spec using Kubernetes API dry-run before submission",
858
+ default=True,
859
+ )
705
860
  return opts
706
861
 
707
862
  def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
863
+ from kubernetes.client.rest import ApiException
864
+
708
865
  namespace, name = app_id.split(":")
709
866
  roles = {}
710
867
  roles_statuses = {}
711
- resp = self._custom_objects_api().get_namespaced_custom_object_status(
712
- group="batch.volcano.sh",
713
- version="v1alpha1",
714
- namespace=namespace,
715
- plural="jobs",
716
- name=name,
717
- )
868
+ try:
869
+ resp = self._custom_objects_api().get_namespaced_custom_object_status(
870
+ group="batch.volcano.sh",
871
+ version="v1alpha1",
872
+ namespace=namespace,
873
+ plural="jobs",
874
+ name=name,
875
+ )
876
+ except ApiException as e:
877
+ if e.status == 404:
878
+ return None
879
+ raise
718
880
  status = resp.get("status")
719
881
  if status:
720
882
  state_str = status["state"]["phase"]
@@ -823,13 +985,34 @@ def create_scheduler(
823
985
  def pod_labels(
824
986
  app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
825
987
  ) -> Dict[str, str]:
988
+
989
+ def clean(label_value: str) -> str:
990
+ # cleans the provided `label_value` to make it compliant
991
+ # to pod label specs as described in
992
+ # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
993
+ #
994
+ # Valid label value:
995
+ # must be 63 characters or less (can be empty),
996
+ # unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
997
+ # could contain dashes (-), underscores (_), dots (.), and alphanumerics between.
998
+
999
+ # Replace invalid characters (allow: alphanum, -, _, .) with "."
1000
+ label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
1001
+ # Replace leading non-alphanumeric with "."
1002
+ label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
1003
+ # Replace trailing non-alphanumeric with "."
1004
+ label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
1005
+
1006
+ # Trim to 63 characters
1007
+ return label_value[:63]
1008
+
826
1009
  return {
827
- LABEL_VERSION: torchx.__version__,
828
- LABEL_APP_NAME: app.name,
1010
+ LABEL_VERSION: clean(torchx.__version__),
1011
+ LABEL_APP_NAME: clean(app.name),
829
1012
  LABEL_ROLE_INDEX: str(role_idx),
830
- LABEL_ROLE_NAME: role.name,
1013
+ LABEL_ROLE_NAME: clean(role.name),
831
1014
  LABEL_REPLICA_ID: str(replica_id),
832
- LABEL_KUBE_APP_NAME: app.name,
1015
+ LABEL_KUBE_APP_NAME: clean(app.name),
833
1016
  LABEL_ORGANIZATION: "torchx.pytorch.org",
834
- LABEL_UNIQUE_NAME: app_id,
1017
+ LABEL_UNIQUE_NAME: clean(app_id),
835
1018
  }
@@ -1159,6 +1159,7 @@ class LogIterator:
1159
1159
  self._check_finished() # check to see if app has finished running
1160
1160
 
1161
1161
  if os.path.isfile(self._log_file):
1162
+ time.sleep(0.1) # fix timing issue
1162
1163
  self._log_fp = open(
1163
1164
  self._log_file,
1164
1165
  mode="rt",
@@ -18,6 +18,7 @@ import os.path
18
18
  import shlex
19
19
  import subprocess
20
20
  import tempfile
21
+ import warnings
21
22
  from dataclasses import dataclass
22
23
  from datetime import datetime
23
24
  from subprocess import CalledProcessError, PIPE
@@ -72,6 +73,64 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
72
73
  return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
73
74
 
74
75
 
76
+ def get_appstate_from_job(job: dict[str, object]) -> AppState:
77
+ # Prior to slurm-23.11, job_state was a string and not a list
78
+ job_state = job.get("job_state", None)
79
+ if isinstance(job_state, list):
80
+ return appstate_from_slurm_state(job_state[0])
81
+ else:
82
+ return appstate_from_slurm_state(str(job_state))
83
+
84
+
85
+ def version() -> Tuple[int, int]:
86
+ """
87
+ Uses ``sinfo --version`` to get the slurm version. If the command fails, it
88
+ assumes the version is ``slurm 24.05.8``.
89
+
90
+ Returns:
91
+ -------
92
+ Tuple[int, int] slurm version as a tuple of ints (major, minor).
93
+ """
94
+
95
+ cmd = ["sinfo", "--version"]
96
+ try:
97
+ out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
98
+ except (CalledProcessError, FileNotFoundError):
99
+ out = "slurm 24.05.8"
100
+ warnings.warn(
101
+ "Error running: `{sinfo_cmd}` to get SLURM version. Are you running outside the "
102
+ "cluster's login or head node? This typically happens when running in `--dryrun`"
103
+ " mode. Assuming version is `slurm 24.05.8`.",
104
+ RuntimeWarning,
105
+ stacklevel=2,
106
+ )
107
+
108
+ # sinfo --version returns in the form "slurm 24.1.0"
109
+ _, version_literal = out.split(" ", maxsplit=2)
110
+ major, minor = [int(v) for v in version_literal.split(".")][:2]
111
+
112
+ return (major, minor)
113
+
114
+
115
+ def _should_use_gpus_per_node_from_version() -> bool:
116
+ """
117
+ Determine whether to use gpus-per-node based on automatically detected slurm version.
118
+
119
+ Change Reference: https://fburl.com/sqwqzxn6
120
+ > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
121
+
122
+ Returns:
123
+ ``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
124
+ """
125
+
126
+ slurm_24_11_0 = (24, 11)
127
+ slurm_version = version()
128
+
129
+ return slurm_version[0] > slurm_24_11_0[0] or ( # Major version is greater
130
+ slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
131
+ ) # Major version is equal and minor version is greater or equal
132
+
133
+
75
134
  SBATCH_JOB_OPTIONS = {
76
135
  "comment",
77
136
  "mail-user",
@@ -81,6 +140,7 @@ SBATCH_GROUP_OPTIONS = {
81
140
  "partition",
82
141
  "time",
83
142
  "constraint",
143
+ "qos",
84
144
  }
85
145
 
86
146
  log: logging.Logger = logging.getLogger(__name__)
@@ -106,6 +166,7 @@ SlurmOpts = TypedDict(
106
166
  "mail-user": Optional[str],
107
167
  "mail-type": Optional[str],
108
168
  "job_dir": Optional[str],
169
+ "qos": Optional[str],
109
170
  },
110
171
  total=False,
111
172
  )
@@ -126,7 +187,11 @@ class SlurmReplicaRequest:
126
187
 
127
188
  @classmethod
128
189
  def from_role(
129
- cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
190
+ cls,
191
+ name: str,
192
+ role: Role,
193
+ cfg: SlurmOpts,
194
+ nomem: bool,
130
195
  ) -> "SlurmReplicaRequest":
131
196
  """
132
197
  ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -149,7 +214,12 @@ class SlurmReplicaRequest:
149
214
  if not nomem and resource.memMB > 0:
150
215
  sbatch_opts.setdefault("mem", str(resource.memMB))
151
216
  if resource.gpu > 0:
152
- sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
217
+ # Use smart GPU allocation based on automatically detected Slurm version
218
+ if _should_use_gpus_per_node_from_version():
219
+ sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
220
+ else:
221
+ sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
222
+ sbatch_opts.setdefault("ntasks", "1")
153
223
 
154
224
  srun_opts = {
155
225
  "output": f"slurm-{macros.app_id}-{name}.out",
@@ -378,6 +448,11 @@ class SlurmScheduler(
378
448
  iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
379
449
  """,
380
450
  )
451
+ opts.add(
452
+ "qos",
453
+ type_=str,
454
+ help="Quality of Service (QoS) to assign to the job.",
455
+ )
381
456
  return opts
382
457
 
383
458
  def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -504,6 +579,8 @@ class SlurmScheduler(
504
579
  return self._describe_sacct(app_id)
505
580
 
506
581
  def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
582
+ # NOTE: Handles multiple job ID formats due to SLURM version differences.
583
+ # Different clusters use heterogeneous (+) vs regular (.) job ID formats.
507
584
  try:
508
585
  output = subprocess.check_output(
509
586
  ["sacct", "--parsable2", "-j", app_id],
@@ -528,15 +605,27 @@ class SlurmScheduler(
528
605
  msg = ""
529
606
  app_state = AppState.UNKNOWN
530
607
  for row in reader:
531
- job_id, *parts = row["JobID"].split("+")
608
+ # Handle both "+" (heterogeneous) and "." (regular) job ID formats
609
+ job_id_full = row["JobID"]
610
+
611
+ # Split on both "+" and "." to handle different SLURM configurations
612
+ if "+" in job_id_full:
613
+ job_id, *parts = job_id_full.split("+")
614
+ is_subjob = len(parts) > 0 and "." in parts[0]
615
+ else:
616
+ job_id, *parts = job_id_full.split(".")
617
+ is_subjob = len(parts) > 0
618
+
532
619
  if job_id != app_id:
533
620
  continue
534
- if len(parts) > 0 and "." in parts[0]:
535
- # we only care about the worker not the child jobs
621
+
622
+ if is_subjob:
623
+ # we only care about the main job not the child jobs (.batch, .0, etc.)
536
624
  continue
537
625
 
538
- state = row["State"]
539
- msg = state
626
+ msg = row["State"]
627
+ # Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
628
+ state = msg.split()[0].rstrip("+")
540
629
  app_state = appstate_from_slurm_state(state)
541
630
 
542
631
  role, _, replica_id = row["JobName"].rpartition("-")
@@ -563,6 +652,9 @@ class SlurmScheduler(
563
652
  )
564
653
 
565
654
  def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
655
+ # NOTE: This method contains multiple compatibility checks for different SLURM versions
656
+ # due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).
657
+
566
658
  # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
567
659
  # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
568
660
  output = subprocess.check_output(
@@ -583,7 +675,7 @@ class SlurmScheduler(
583
675
 
584
676
  entrypoint = job["command"]
585
677
  image = job["current_working_directory"]
586
- state = appstate_from_slurm_state(job["job_state"][0])
678
+ state = get_appstate_from_job(job)
587
679
 
588
680
  job_resources = job["job_resources"]
589
681
 
@@ -604,7 +696,18 @@ class SlurmScheduler(
604
696
  if state == AppState.PENDING:
605
697
  # NOTE: torchx launched jobs points to exactly one host
606
698
  # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
607
- hostname = job_resources.get("scheduled_nodes", "")
699
+
700
+ # SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
701
+ if job_resources is not None:
702
+ hostname = job_resources.get("scheduled_nodes", "")
703
+ # If scheduled_nodes not found in job_resources, try nodes.list
704
+ if not hostname and "nodes" in job_resources:
705
+ nodes_info = job_resources.get("nodes", {})
706
+ if isinstance(nodes_info, dict):
707
+ hostname = nodes_info.get("list", "")
708
+ else:
709
+ # For pending jobs where job_resources is None, check top-level fields
710
+ hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")
608
711
 
609
712
  role.num_replicas += 1
610
713
  role_status.replicas.append(
@@ -620,24 +723,35 @@ class SlurmScheduler(
620
723
  # where each replica is a "sub-job" so `allocated_nodes` will always be 1
621
724
  # but we deal with jobs that have not been launched with torchx
622
725
  # which can have multiple hosts per sub-job (count them as replicas)
623
- node_infos = job_resources.get("allocated_nodes", [])
726
+ nodes_data = job_resources.get("nodes", {})
727
+
728
+ # SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
729
+ if "allocation" in nodes_data and isinstance(
730
+ nodes_data["allocation"], list
731
+ ):
732
+ # SLURM 24.11+ format: nodes.allocation is a list
733
+ for node_info in nodes_data["allocation"]:
734
+ hostname = node_info["name"]
735
+ cpu = int(node_info["cpus"]["used"])
736
+ memMB = (
737
+ int(node_info["memory"]["allocated"]) // 1024
738
+ ) # Convert to MB
624
739
 
625
- if not isinstance(node_infos, list):
626
- # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
627
- # is not a list of individual nodes, but a map of the nodelist specs
628
- # in this case just use jobs[].job_resources.nodes
629
- hostname = job_resources.get("nodes")
630
- role.num_replicas += 1
631
- role_status.replicas.append(
632
- ReplicaStatus(
633
- id=int(replica_id),
634
- role=role_name,
635
- state=state,
636
- hostname=hostname,
740
+ role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
741
+ role.num_replicas += 1
742
+ role_status.replicas.append(
743
+ ReplicaStatus(
744
+ id=int(replica_id),
745
+ role=role_name,
746
+ state=state,
747
+ hostname=hostname,
748
+ )
637
749
  )
638
- )
639
- else:
640
- for node_info in node_infos:
750
+ elif "allocated_nodes" in job_resources and isinstance(
751
+ job_resources["allocated_nodes"], list
752
+ ):
753
+ # Legacy format: allocated_nodes is a list
754
+ for node_info in job_resources["allocated_nodes"]:
641
755
  # NOTE: we expect resource specs for all the nodes to be the same
642
756
  # NOTE: use allocated (not used/requested) memory since
643
757
  # users may only specify --cpu, in which case slurm
@@ -660,6 +774,26 @@ class SlurmScheduler(
660
774
  hostname=hostname,
661
775
  )
662
776
  )
777
+ else:
778
+ # Fallback: use hostname from nodes.list
779
+ if isinstance(nodes_data, str):
780
+ hostname = nodes_data
781
+ else:
782
+ hostname = (
783
+ nodes_data.get("list", "")
784
+ if isinstance(nodes_data, dict)
785
+ else ""
786
+ )
787
+
788
+ role.num_replicas += 1
789
+ role_status.replicas.append(
790
+ ReplicaStatus(
791
+ id=int(replica_id),
792
+ role=role_name,
793
+ state=state,
794
+ hostname=hostname,
795
+ )
796
+ )
663
797
 
664
798
  return DescribeAppResponse(
665
799
  app_id=app_id,
@@ -756,7 +890,7 @@ class SlurmScheduler(
756
890
  out.append(
757
891
  ListAppResponse(
758
892
  app_id=str(job["job_id"]),
759
- state=SLURM_STATES[job["job_state"][0]],
893
+ state=get_appstate_from_job(job),
760
894
  name=job["name"],
761
895
  )
762
896
  )