torchx-nightly 2025.8.5__py3-none-any.whl → 2025.11.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +92 -81
- torchx/runner/config.py +3 -1
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +20 -15
- torchx/schedulers/aws_batch_scheduler.py +45 -2
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/local_scheduler.py +1 -0
- torchx/schedulers/slurm_scheduler.py +93 -24
- torchx/specs/__init__.py +23 -6
- torchx/specs/api.py +219 -11
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0

torchx/schedulers/kubernetes_scheduler.py
CHANGED

@@ -27,10 +27,81 @@ Install Volcano:
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
+the ``kubernetes`` metadata on your role. The value can be:
+
+- A dict with the overlay structure
+- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
+
+Merge semantics:
+- **dict**: recursive merge (upsert)
+- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
+- **primitives**: replace
+
+.. code:: python
+
+    from torchx.specs import Role
+
+    # Dict overlay - lists append, tuples replace
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": {
+                "spec": {
+                    "nodeSelector": {"gpu": "true"},
+                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends
+                    "volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces
+                }
+            }
+        }
+    )
+
+    # File URI overlay
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": "file:///path/to/pod_overlay.yaml"
+        }
+    )
+
+CLI usage with builtin components:
+
+.. code:: bash
+
+    $ torchx run --scheduler kubernetes dist.ddp \\
+        --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
+        --script train.py
+
+Example ``pod_overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+      - key: nvidia.com/gpu
+        operator: Exists
+        effect: NoSchedule
+      volumes: !!python/tuple
+      - name: my-volume
+        emptyDir: {}
+
+The overlay is deep-merged with the generated pod, preserving existing fields
+and adding or overriding specified ones.
 """
 
 import json
 import logging
+import re
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
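
For reference, the list-vs-tuple distinction documented above is the subtle part of the overlay. The sketch below mirrors those merge rules on plain dicts; the helper copies the deep_merge logic added later in this diff, and the field values are illustrative only.

    from typing import Any, Dict

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        # Rules described above: dict -> recursive upsert, tuple -> replace,
        # list -> append, primitive -> replace.
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)
            elif isinstance(value, tuple):
                base[key] = list(value)
            elif isinstance(value, list) and key in base and isinstance(base[key], list):
                base[key].extend(value)
            else:
                base[key] = value

    pod = {"spec": {"volumes": [{"name": "base-vol"}], "hostNetwork": False}}
    deep_merge(pod, {"spec": {"volumes": [{"name": "extra-vol"}], "hostNetwork": True}})
    assert [v["name"] for v in pod["spec"]["volumes"]] == ["base-vol", "extra-vol"]  # list appended
    deep_merge(pod, {"spec": {"volumes": ({"name": "only-vol"},)}})
    assert [v["name"] for v in pod["spec"]["volumes"]] == ["only-vol"]  # tuple replaced
    assert pod["spec"]["hostNetwork"] is True  # primitive replaced
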
@@ -45,6 +116,7 @@ from typing import (
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )
 
 import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024
 
+
+def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+    """Apply overlay dict to V1Pod object, merging nested fields.
+
+    Merge semantics:
+    - dict: upsert (recursive merge)
+    - list: append by default, replace if tuple
+    - primitives: replace
+    """
+    from kubernetes import client
+
+    api = client.ApiClient()
+    pod_dict = api.sanitize_for_serialization(pod)
+
+    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+        for key, value in overlay.items():
+            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                deep_merge(base[key], value)
+            elif isinstance(value, tuple):
+                base[key] = list(value)
+            elif (
+                isinstance(value, list) and key in base and isinstance(base[key], list)
+            ):
+                base[key].extend(value)
+            else:
+                base[key] = value
+
+    deep_merge(pod_dict, overlay)
+
+    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+    pod.spec = merged_pod.spec
+    pod.metadata = merged_pod.metadata
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
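
A hedged usage sketch of the `_apply_pod_overlay` helper above: it builds a throwaway V1Pod with the official `kubernetes` client and overlays a `nodeSelector`. The pod name, container, and selector values are placeholders; only the helper itself comes from this release.

    from kubernetes import client
    from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay

    # Minimal pod; names and image are placeholders.
    pod = client.V1Pod(
        metadata=client.V1ObjectMeta(name="trainer-0", labels={}),
        spec=client.V1PodSpec(
            containers=[client.V1Container(name="trainer", image="my-image:latest")]
        ),
    )
    _apply_pod_overlay(pod, {"spec": {"nodeSelector": {"gpu": "true"}}})
    assert pod.spec.node_selector == {"gpu": "true"}
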
@@ -369,7 +475,7 @@ def app_to_resource(
     queue: str,
     service_account: Optional[str],
     priority_class: Optional[str] = None,
-) -> Dict[str,
+) -> Dict[str, Any]:
     """
     app_to_resource creates a volcano job kubernetes resource definition from
     the provided AppDef. The resource definition can be used to launch the
@@ -399,8 +505,20 @@ def app_to_resource(
             replica_role = values.apply(role)
             if role_idx == 0 and replica_id == 0:
                 replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+                replica_role.env["TORCHX_IMAGE"] = replica_role.image
 
             pod = role_to_pod(name, replica_role, service_account)
+            if k8s_metadata := role.metadata.get("kubernetes"):
+                if isinstance(k8s_metadata, str):
+                    import fsspec
+
+                    with fsspec.open(k8s_metadata, "r") as f:
+                        k8s_metadata = yaml.unsafe_load(f)
+                elif not isinstance(k8s_metadata, dict):
+                    raise ValueError(
+                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
+                    )
+                _apply_pod_overlay(pod, k8s_metadata)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
@@ -443,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
     if priority_class is not None:
         job_spec["priorityClassName"] = priority_class
 
-    resource: Dict[str,
+    resource: Dict[str, Any] = {
         "apiVersion": "batch.volcano.sh/v1alpha1",
         "kind": "Job",
         "metadata": {"name": f"{unique_app_id}"},
@@ -455,7 +573,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
 @dataclass
 class KubernetesJob:
     images_to_push: Dict[str, Tuple[str, str]]
-    resource: Dict[str,
+    resource: Dict[str, Any]
 
     def __str__(self) -> str:
         return yaml.dump(sanitize_for_serialization(self.resource))
@@ -470,6 +588,7 @@ class KubernetesOpts(TypedDict, total=False):
     image_repo: Optional[str]
     service_account: Optional[str]
     priority_class: Optional[str]
+    validate_spec: Optional[bool]
 
 
 class KubernetesScheduler(
@@ -485,7 +604,7 @@ class KubernetesScheduler(
     For installation instructions see: https://github.com/volcano-sh/volcano
 
     This has been confirmed to work with Volcano v1.3.0 and Kubernetes versions
-    v1.18-1.21. See https://github.com/pytorch/torchx/issues/120 which is
+    v1.18-1.21. See https://github.com/meta-pytorch/torchx/issues/120 which is
     tracking Volcano support for Kubernetes v1.22.
 
     .. note::
@@ -635,7 +754,7 @@ class KubernetesScheduler(
             else:
                 raise
 
-        return f
+        return f"{namespace}:{resp['metadata']['name']}"
 
     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts
@@ -658,6 +777,36 @@ class KubernetesScheduler(
         ), "priority_class must be a str"
 
         resource = app_to_resource(app, queue, service_account, priority_class)
+
+        if cfg.get("validate_spec"):
+            try:
+                self._custom_objects_api().create_namespaced_custom_object(
+                    group="batch.volcano.sh",
+                    version="v1alpha1",
+                    namespace=cfg.get("namespace") or "default",
+                    plural="jobs",
+                    body=resource,
+                    dry_run="All",
+                )
+            except Exception as e:
+                from kubernetes.client.rest import ApiException
+
+                if isinstance(e, ApiException):
+                    raise ValueError(f"Invalid job spec: {e.reason}") from e
+                raise
+
+        job_name = resource["metadata"]["name"]
+        for task in resource["spec"]["tasks"]:
+            task_name = task["name"]
+            replicas = task.get("replicas", 1)
+            max_index = replicas - 1
+            pod_name = f"{job_name}-{task_name}-{max_index}"
+            if len(pod_name) > 63:
+                raise ValueError(
+                    f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
+                    f"Shorten app.name or role names"
+                )
+
         req = KubernetesJob(
             resource=resource,
             images_to_push=images_to_push,
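
The dry-run validation and the 63-character check above both run before anything is submitted; the pod name is assembled roughly as `<job name>-<task name>-<replica index>`, so overly long app or role names fail fast. A quick pre-check with hypothetical values:

    job_name = "my-training-app-abc123"   # unique app id of the Volcano job
    task_name = "trainer-0"               # task names look like "<role name>-<index>"
    max_index = 3                         # replicas - 1
    pod_name = f"{job_name}-{task_name}-{max_index}"
    assert len(pod_name) <= 63, f"{pod_name!r} is {len(pod_name)} chars"
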
@@ -702,19 +851,32 @@ class KubernetesScheduler(
             type_=str,
             help="The name of the PriorityClass to set on the job specs",
         )
+        opts.add(
+            "validate_spec",
+            type_=bool,
+            help="Validate job spec using Kubernetes API dry-run before submission",
+            default=True,
+        )
         return opts
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        from kubernetes.client.rest import ApiException
+
         namespace, name = app_id.split(":")
         roles = {}
         roles_statuses = {}
-
-
-
-
-
-
-
+        try:
+            resp = self._custom_objects_api().get_namespaced_custom_object_status(
+                group="batch.volcano.sh",
+                version="v1alpha1",
+                namespace=namespace,
+                plural="jobs",
+                name=name,
+            )
+        except ApiException as e:
+            if e.status == 404:
+                return None
+            raise
         status = resp.get("status")
         if status:
             state_str = status["state"]["phase"]
@@ -823,13 +985,34 @@ def create_scheduler(
 def pod_labels(
     app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
 ) -> Dict[str, str]:
+
+    def clean(label_value: str) -> str:
+        # cleans the provided `label_value` to make it compliant
+        # to pod label specs as described in
+        # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
+        #
+        # Valid label value:
+        # must be 63 characters or less (can be empty),
+        # unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
+        # could contain dashes (-), underscores (_), dots (.), and alphanumerics between.
+
+        # Replace invalid characters (allow: alphanum, -, _, .) with "."
+        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
+        # Replace leading non-alphanumeric with "."
+        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
+        # Replace trailing non-alphanumeric with "."
+        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
+
+        # Trim to 63 characters
+        return label_value[:63]
+
     return {
-        LABEL_VERSION: torchx.__version__,
-        LABEL_APP_NAME: app.name,
+        LABEL_VERSION: clean(torchx.__version__),
+        LABEL_APP_NAME: clean(app.name),
         LABEL_ROLE_INDEX: str(role_idx),
-        LABEL_ROLE_NAME: role.name,
+        LABEL_ROLE_NAME: clean(role.name),
         LABEL_REPLICA_ID: str(replica_id),
-        LABEL_KUBE_APP_NAME: app.name,
+        LABEL_KUBE_APP_NAME: clean(app.name),
         LABEL_ORGANIZATION: "torchx.pytorch.org",
-        LABEL_UNIQUE_NAME: app_id,
+        LABEL_UNIQUE_NAME: clean(app_id),
     }
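
To see what the new `clean()` helper does to label values, here is a standalone copy of its three substitutions applied to sample inputs (the values are hypothetical; the first mimics a local-version string of the kind `torchx.__version__` can produce):

    import re

    def clean(label_value: str) -> str:
        # Same substitutions and truncation as the clean() helper above.
        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
        return label_value[:63]

    assert clean("0.8.0dev0+cpu") == "0.8.0dev0.cpu"    # "+" is not label-safe
    assert clean("my app/trainer") == "my.app.trainer"  # spaces and "/" replaced
    assert clean("a" * 100) == "a" * 63                 # trimmed to 63 characters
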
torchx/schedulers/slurm_scheduler.py
CHANGED

@@ -73,6 +73,15 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
 
 
+def get_appstate_from_job(job: dict[str, object]) -> AppState:
+    # Prior to slurm-23.11, job_state was a string and not a list
+    job_state = job.get("job_state", None)
+    if isinstance(job_state, list):
+        return appstate_from_slurm_state(job_state[0])
+    else:
+        return appstate_from_slurm_state(str(job_state))
+
+
 def version() -> Tuple[int, int]:
     """
     Uses ``sinfo --version`` to get the slurm version. If the command fails, it
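
A hedged sketch of the two `job_state` shapes this helper accepts. The payloads are hypothetical fragments of `squeue --json` output, and the assertions assume `SLURM_STATES` maps "RUNNING" to `AppState.RUNNING`, as it does elsewhere in this module.

    from torchx.schedulers.slurm_scheduler import get_appstate_from_job
    from torchx.specs.api import AppState

    pre_23_11 = {"job_state": "RUNNING"}     # older slurm: a plain string
    post_23_11 = {"job_state": ["RUNNING"]}  # slurm-23.11+: a list of states

    assert get_appstate_from_job(pre_23_11) == AppState.RUNNING
    assert get_appstate_from_job(post_23_11) == AppState.RUNNING
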
@@ -210,6 +219,7 @@ class SlurmReplicaRequest:
             sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
         else:
             sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+        sbatch_opts.setdefault("ntasks", "1")
 
         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -569,6 +579,8 @@ class SlurmScheduler(
         return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: Handles multiple job ID formats due to SLURM version differences.
+        # Different clusters use heterogeneous (+) vs regular (.) job ID formats.
         try:
             output = subprocess.check_output(
                 ["sacct", "--parsable2", "-j", app_id],
@@ -593,15 +605,27 @@ class SlurmScheduler(
         msg = ""
         app_state = AppState.UNKNOWN
         for row in reader:
-
+            # Handle both "+" (heterogeneous) and "." (regular) job ID formats
+            job_id_full = row["JobID"]
+
+            # Split on both "+" and "." to handle different SLURM configurations
+            if "+" in job_id_full:
+                job_id, *parts = job_id_full.split("+")
+                is_subjob = len(parts) > 0 and "." in parts[0]
+            else:
+                job_id, *parts = job_id_full.split(".")
+                is_subjob = len(parts) > 0
+
             if job_id != app_id:
                 continue
-
-
+
+            if is_subjob:
+                # we only care about the main job not the child jobs (.batch, .0, etc.)
                 continue
 
-
-
+            msg = row["State"]
+            # Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
+            state = msg.split()[0].rstrip("+")
             app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
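
To make the branching concrete, here is the same split logic applied to a few hypothetical sacct `JobID` values, showing which rows are treated as sub-jobs and skipped:

    def parse(job_id_full: str) -> tuple:
        # Mirrors the JobID handling in the hunk above.
        if "+" in job_id_full:
            job_id, *parts = job_id_full.split("+")
            is_subjob = len(parts) > 0 and "." in parts[0]
        else:
            job_id, *parts = job_id_full.split(".")
            is_subjob = len(parts) > 0
        return job_id, is_subjob

    assert parse("1234") == ("1234", False)         # plain main job
    assert parse("1234.batch") == ("1234", True)    # regular sub-job, skipped
    assert parse("1234+0") == ("1234", False)       # heterogeneous main component
    assert parse("1234+0.batch") == ("1234", True)  # heterogeneous sub-job, skipped
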
@@ -628,6 +652,9 @@ class SlurmScheduler(
         )
 
     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: This method contains multiple compatibility checks for different SLURM versions
+        # due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).
+
         # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
         # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
         output = subprocess.check_output(
@@ -648,7 +675,7 @@ class SlurmScheduler(
 
             entrypoint = job["command"]
             image = job["current_working_directory"]
-            state =
+            state = get_appstate_from_job(job)
 
             job_resources = job["job_resources"]
 
@@ -669,7 +696,18 @@ class SlurmScheduler(
             if state == AppState.PENDING:
                 # NOTE: torchx launched jobs points to exactly one host
                 # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
-
+
+                # SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
+                if job_resources is not None:
+                    hostname = job_resources.get("scheduled_nodes", "")
+                    # If scheduled_nodes not found in job_resources, try nodes.list
+                    if not hostname and "nodes" in job_resources:
+                        nodes_info = job_resources.get("nodes", {})
+                        if isinstance(nodes_info, dict):
+                            hostname = nodes_info.get("list", "")
+                else:
+                    # For pending jobs where job_resources is None, check top-level fields
+                    hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")
 
                 role.num_replicas += 1
                 role_status.replicas.append(
@@ -685,24 +723,35 @@ class SlurmScheduler(
                 # where each replica is a "sub-job" so `allocated_nodes` will always be 1
                 # but we deal with jobs that have not been launched with torchx
                 # which can have multiple hosts per sub-job (count them as replicas)
-
+                nodes_data = job_resources.get("nodes", {})
+
+                # SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
+                if "allocation" in nodes_data and isinstance(
+                    nodes_data["allocation"], list
+                ):
+                    # SLURM 24.11+ format: nodes.allocation is a list
+                    for node_info in nodes_data["allocation"]:
+                        hostname = node_info["name"]
+                        cpu = int(node_info["cpus"]["used"])
+                        memMB = (
+                            int(node_info["memory"]["allocated"]) // 1024
+                        ) # Convert to MB
 
-
-
-
-
-
-
-
-
-
-                            role=role_name,
-                            state=state,
-                            hostname=hostname,
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
                         )
-
-
-
+                elif "allocated_nodes" in job_resources and isinstance(
+                    job_resources["allocated_nodes"], list
+                ):
+                    # Legacy format: allocated_nodes is a list
+                    for node_info in job_resources["allocated_nodes"]:
                         # NOTE: we expect resource specs for all the nodes to be the same
                         # NOTE: use allocated (not used/requested) memory since
                         # users may only specify --cpu, in which case slurm
@@ -725,6 +774,26 @@ class SlurmScheduler(
                                 hostname=hostname,
                             )
                         )
+                else:
+                    # Fallback: use hostname from nodes.list
+                    if isinstance(nodes_data, str):
+                        hostname = nodes_data
+                    else:
+                        hostname = (
+                            nodes_data.get("list", "")
+                            if isinstance(nodes_data, dict)
+                            else ""
+                        )
+
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
 
         return DescribeAppResponse(
             app_id=app_id,
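
A hedged illustration of the `job_resources` shapes distinguished above. The real payloads vary by SLURM version; the dict below is hypothetical and contains only the fields the new branch reads (older SLURM exposes `job_resources["allocated_nodes"]` instead, which the legacy branch consumes).

    # SLURM 24.11+ style: nodes.allocation is a list of per-node dicts.
    job_resources = {
        "nodes": {
            "allocation": [
                {"name": "node-0", "cpus": {"used": 8}, "memory": {"allocated": 65536}},
            ]
        }
    }

    nodes_data = job_resources.get("nodes", {})
    for node_info in nodes_data["allocation"]:
        hostname = node_info["name"]
        cpu = int(node_info["cpus"]["used"])
        memMB = int(node_info["memory"]["allocated"]) // 1024
    assert (hostname, cpu, memMB) == ("node-0", 8, 64)
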
@@ -821,7 +890,7 @@ class SlurmScheduler(
             out.append(
                 ListAppResponse(
                     app_id=str(job["job_id"]),
-                    state=
+                    state=get_appstate_from_job(job),
                     name=job["name"],
                 )
             )

torchx/specs/__init__.py
CHANGED

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -13,6 +12,8 @@ used by components to define the apps which can then be launched via a TorchX
 scheduler or pipeline adapter.
 """
 import difflib
+
+import os
 from typing import Callable, Dict, Mapping, Optional
 
 from torchx.specs.api import (
@@ -42,9 +43,11 @@ from torchx.specs.api import (
     RoleStatus,
     runopt,
     runopts,
+    TORCHX_HOME,
     UnknownAppException,
     UnknownSchedulerException,
     VolumeMount,
+    Workspace,
 )
 from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts
 
@@ -52,14 +55,22 @@ from torchx.util.entrypoints import load_group
 
 from torchx.util.modules import import_attr
 
-
+GiB: int = 1024
+
+
+ResourceFactory = Callable[[], Resource]
+
+AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
     "torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={}
 )
-GENERIC_NAMED_RESOURCES: Mapping[str,
+GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
     "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
 )
-
-
+CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+    os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
+    "NAMED_RESOURCES",
+    default={},
+)
 
 
 def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
@@ -69,6 +80,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
     for name, resource in {
         **GENERIC_NAMED_RESOURCES,
         **AWS_NAMED_RESOURCES,
+        **CUSTOM_NAMED_RESOURCES,
         **resource_methods,
     }.items():
         materialized_resources[name] = resource
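
The new `CUSTOM_NAMED_RESOURCES` hook loads a `NAMED_RESOURCES` mapping from whichever module the `TORCHX_CUSTOM_NAMED_RESOURCES` environment variable names (falling back to the internal `torchx.specs.fb.named_resources` path, and to an empty mapping when the import fails). A sketch of such a module; the module path and resource name are hypothetical:

    # my_pkg/named_resources.py -- used via:
    #   export TORCHX_CUSTOM_NAMED_RESOURCES=my_pkg.named_resources
    from torchx.specs import Resource

    NAMED_RESOURCES = {
        # 4 GPUs, 32 vCPUs, ~244 GiB host memory (illustrative numbers)
        "my_gpu_node": lambda: Resource(cpu=32, gpu=4, memMB=244 * 1024),
    }

Components that accept a named resource (for example the `h` argument of `dist.ddp`) can then reference `my_gpu_node` by name.
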
@@ -122,7 +134,7 @@ def resource(
 
     If ``h`` is specified then it is used to look up the
     resource specs from the list of registered named resources.
-    See `registering named resource <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resource <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
 
     Otherwise a ``Resource`` object is created from the raw resource specs.
 
@@ -225,5 +237,10 @@ __all__ = [
     "make_app_handle",
     "materialize_appdef",
     "parse_mounts",
+    "torchx_run_args_from_argparse",
+    "torchx_run_args_from_json",
+    "TorchXRunArgs",
     "ALL",
+    "TORCHX_HOME",
+    "Workspace",
 ]