torchx-nightly 2025.8.5__py3-none-any.whl → 2025.11.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  2. torchx/cli/cmd_list.py +1 -2
  3. torchx/cli/cmd_run.py +202 -28
  4. torchx/cli/cmd_tracker.py +1 -1
  5. torchx/components/__init__.py +1 -8
  6. torchx/components/dist.py +9 -3
  7. torchx/components/integration_tests/component_provider.py +2 -2
  8. torchx/components/utils.py +1 -1
  9. torchx/distributed/__init__.py +1 -1
  10. torchx/runner/api.py +92 -81
  11. torchx/runner/config.py +3 -1
  12. torchx/runner/events/__init__.py +20 -10
  13. torchx/runner/events/api.py +1 -1
  14. torchx/schedulers/__init__.py +7 -10
  15. torchx/schedulers/api.py +20 -15
  16. torchx/schedulers/aws_batch_scheduler.py +45 -2
  17. torchx/schedulers/docker_scheduler.py +3 -0
  18. torchx/schedulers/kubernetes_scheduler.py +200 -17
  19. torchx/schedulers/local_scheduler.py +1 -0
  20. torchx/schedulers/slurm_scheduler.py +93 -24
  21. torchx/specs/__init__.py +23 -6
  22. torchx/specs/api.py +219 -11
  23. torchx/specs/builders.py +109 -28
  24. torchx/specs/file_linter.py +117 -53
  25. torchx/specs/finder.py +25 -37
  26. torchx/specs/named_resources_aws.py +13 -2
  27. torchx/tracker/__init__.py +2 -2
  28. torchx/tracker/api.py +1 -1
  29. torchx/util/entrypoints.py +1 -6
  30. torchx/util/strings.py +1 -1
  31. torchx/util/types.py +12 -1
  32. torchx/version.py +2 -2
  33. torchx/workspace/api.py +102 -5
  34. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
  35. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
  36. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
  37. torchx/examples/pipelines/__init__.py +0 -0
  38. torchx/examples/pipelines/kfp/__init__.py +0 -0
  39. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
  40. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
  41. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
  42. torchx/pipelines/kfp/__init__.py +0 -30
  43. torchx/pipelines/kfp/adapter.py +0 -274
  44. torchx/pipelines/kfp/version.py +0 -19
  45. torchx/schedulers/gcp_batch_scheduler.py +0 -497
  46. torchx/schedulers/ray/ray_common.py +0 -22
  47. torchx/schedulers/ray/ray_driver.py +0 -307
  48. torchx/schedulers/ray_scheduler.py +0 -454
  49. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
  50. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
  51. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/schedulers/kubernetes_scheduler.py CHANGED
@@ -27,10 +27,81 @@ Install Volcano:
  See the
  `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
  for more information.
+
+ Pod Overlay
+ ===========
+
+ You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
+ the ``kubernetes`` metadata on your role. The value can be:
+
+ - A dict with the overlay structure
+ - A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
+
+ Merge semantics:
+ - **dict**: recursive merge (upsert)
+ - **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
+ - **primitives**: replace
+
+ .. code:: python
+
+     from torchx.specs import Role
+
+     # Dict overlay - lists append, tuples replace
+     role = Role(
+         name="trainer",
+         image="my-image:latest",
+         entrypoint="train.py",
+         metadata={
+             "kubernetes": {
+                 "spec": {
+                     "nodeSelector": {"gpu": "true"},
+                     "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],  # appends
+                     "volumes": ({"name": "my-volume", "emptyDir": {}},)  # replaces
+                 }
+             }
+         }
+     )
+
+     # File URI overlay
+     role = Role(
+         name="trainer",
+         image="my-image:latest",
+         entrypoint="train.py",
+         metadata={
+             "kubernetes": "file:///path/to/pod_overlay.yaml"
+         }
+     )
+
+ CLI usage with builtin components:
+
+ .. code:: bash
+
+     $ torchx run --scheduler kubernetes dist.ddp \\
+         --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
+         --script train.py
+
+ Example ``pod_overlay.yaml``:
+
+ .. code:: yaml
+
+     spec:
+       nodeSelector:
+         node.kubernetes.io/instance-type: p4d.24xlarge
+       tolerations:
+       - key: nvidia.com/gpu
+         operator: Exists
+         effect: NoSchedule
+       volumes: !!python/tuple
+       - name: my-volume
+         emptyDir: {}
+
+ The overlay is deep-merged with the generated pod, preserving existing fields
+ and adding or overriding specified ones.
  """

  import json
  import logging
+ import re
  import warnings
  from dataclasses import dataclass
  from datetime import datetime
@@ -45,6 +116,7 @@ from typing import (
      Tuple,
      TYPE_CHECKING,
      TypedDict,
+     Union,
  )

  import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
  RESERVED_MILLICPU = 100
  RESERVED_MEMMB = 1024

+
+ def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+     """Apply overlay dict to V1Pod object, merging nested fields.
+
+     Merge semantics:
+     - dict: upsert (recursive merge)
+     - list: append by default, replace if tuple
+     - primitives: replace
+     """
+     from kubernetes import client
+
+     api = client.ApiClient()
+     pod_dict = api.sanitize_for_serialization(pod)
+
+     def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+         for key, value in overlay.items():
+             if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                 deep_merge(base[key], value)
+             elif isinstance(value, tuple):
+                 base[key] = list(value)
+             elif (
+                 isinstance(value, list) and key in base and isinstance(base[key], list)
+             ):
+                 base[key].extend(value)
+             else:
+                 base[key] = value
+
+     deep_merge(pod_dict, overlay)
+
+     merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+     pod.spec = merged_pod.spec
+     pod.metadata = merged_pod.metadata
+
+
  RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
      RetryPolicy.REPLICA: [],
      RetryPolicy.APPLICATION: [
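
A minimal sketch of the merge behavior that ``deep_merge`` above implements, using toy dicts (the field values here are invented for illustration):

.. code:: python

    base = {
        "spec": {
            "nodeSelector": {"pool": "default"},
            "tolerations": [{"key": "existing", "operator": "Exists"}],
            "volumes": [{"name": "old-volume", "emptyDir": {}}],
            "restartPolicy": "Never",
        }
    }
    overlay = {
        "spec": {
            "nodeSelector": {"gpu": "true"},                                   # dict: upsert
            "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],  # list: append
            "volumes": ({"name": "my-volume", "emptyDir": {}},),               # tuple: replace
            "restartPolicy": "OnFailure",                                      # primitive: replace
        }
    }

    # After deep_merge(base, overlay):
    #   base["spec"]["nodeSelector"]  == {"pool": "default", "gpu": "true"}
    #   base["spec"]["tolerations"]   contains both entries (appended)
    #   base["spec"]["volumes"]       == [{"name": "my-volume", "emptyDir": {}}]  (replaced)
    #   base["spec"]["restartPolicy"] == "OnFailure"
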
@@ -369,7 +475,7 @@ def app_to_resource(
      queue: str,
      service_account: Optional[str],
      priority_class: Optional[str] = None,
- ) -> Dict[str, object]:
+ ) -> Dict[str, Any]:
      """
      app_to_resource creates a volcano job kubernetes resource definition from
      the provided AppDef. The resource definition can be used to launch the
@@ -399,8 +505,20 @@ def app_to_resource(
              replica_role = values.apply(role)
              if role_idx == 0 and replica_id == 0:
                  replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+                 replica_role.env["TORCHX_IMAGE"] = replica_role.image

              pod = role_to_pod(name, replica_role, service_account)
+             if k8s_metadata := role.metadata.get("kubernetes"):
+                 if isinstance(k8s_metadata, str):
+                     import fsspec
+
+                     with fsspec.open(k8s_metadata, "r") as f:
+                         k8s_metadata = yaml.unsafe_load(f)
+                 elif not isinstance(k8s_metadata, dict):
+                     raise ValueError(
+                         f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
+                     )
+                 _apply_pod_overlay(pod, k8s_metadata)
              pod.metadata.labels.update(
                  pod_labels(
                      app=app,
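
When the overlay is given as a resource URI, the YAML is loaded with ``yaml.unsafe_load`` so that the ``!!python/tuple`` tag produces an actual Python tuple, which the merge then treats as "replace" rather than "append". A small hedged sketch of that behavior with PyYAML (the document below is illustrative):

.. code:: python

    import yaml  # PyYAML

    doc = """
    spec:
      volumes: !!python/tuple
      - name: my-volume
        emptyDir: {}
    """
    overlay = yaml.unsafe_load(doc)
    # the tagged sequence deserializes as a tuple, so it replaces instead of appending
    assert isinstance(overlay["spec"]["volumes"], tuple)
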
@@ -443,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
      if priority_class is not None:
          job_spec["priorityClassName"] = priority_class

-     resource: Dict[str, object] = {
+     resource: Dict[str, Any] = {
          "apiVersion": "batch.volcano.sh/v1alpha1",
          "kind": "Job",
          "metadata": {"name": f"{unique_app_id}"},
@@ -455,7 +573,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
  @dataclass
  class KubernetesJob:
      images_to_push: Dict[str, Tuple[str, str]]
-     resource: Dict[str, object]
+     resource: Dict[str, Any]

      def __str__(self) -> str:
          return yaml.dump(sanitize_for_serialization(self.resource))
@@ -470,6 +588,7 @@ class KubernetesOpts(TypedDict, total=False):
      image_repo: Optional[str]
      service_account: Optional[str]
      priority_class: Optional[str]
+     validate_spec: Optional[bool]


  class KubernetesScheduler(
@@ -485,7 +604,7 @@ class KubernetesScheduler(
      For installation instructions see: https://github.com/volcano-sh/volcano

      This has been confirmed to work with Volcano v1.3.0 and Kubernetes versions
-     v1.18-1.21. See https://github.com/pytorch/torchx/issues/120 which is
+     v1.18-1.21. See https://github.com/meta-pytorch/torchx/issues/120 which is
      tracking Volcano support for Kubernetes v1.22.

      .. note::
635
754
  else:
636
755
  raise
637
756
 
638
- return f'{namespace}:{resp["metadata"]["name"]}'
757
+ return f"{namespace}:{resp['metadata']['name']}"
639
758
 
640
759
  def _submit_dryrun(
641
760
  self, app: AppDef, cfg: KubernetesOpts
@@ -658,6 +777,36 @@ class KubernetesScheduler(
          ), "priority_class must be a str"

          resource = app_to_resource(app, queue, service_account, priority_class)
+
+         if cfg.get("validate_spec"):
+             try:
+                 self._custom_objects_api().create_namespaced_custom_object(
+                     group="batch.volcano.sh",
+                     version="v1alpha1",
+                     namespace=cfg.get("namespace") or "default",
+                     plural="jobs",
+                     body=resource,
+                     dry_run="All",
+                 )
+             except Exception as e:
+                 from kubernetes.client.rest import ApiException
+
+                 if isinstance(e, ApiException):
+                     raise ValueError(f"Invalid job spec: {e.reason}") from e
+                 raise
+
+         job_name = resource["metadata"]["name"]
+         for task in resource["spec"]["tasks"]:
+             task_name = task["name"]
+             replicas = task.get("replicas", 1)
+             max_index = replicas - 1
+             pod_name = f"{job_name}-{task_name}-{max_index}"
+             if len(pod_name) > 63:
+                 raise ValueError(
+                     f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
+                     f"Shorten app.name or role names"
+                 )
+
          req = KubernetesJob(
              resource=resource,
              images_to_push=images_to_push,
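
The 63-character cap checked above corresponds to the DNS-label length limit that Kubernetes applies to pod names and related fields. A quick sketch of how the longest generated pod name is formed, with made-up names (the real values come from the Volcano job resource built above):

.. code:: python

    job_name = "my-train-job-abc123"   # unique app id used as the Volcano job name
    task_name = "trainer-0"            # task names look like "{role_name}-{replica_id}"
    max_index = 0                      # replicas - 1, the highest replica index
    pod_name = f"{job_name}-{task_name}-{max_index}"

    assert len(pod_name) <= 63, f"{pod_name} would be rejected"
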
@@ -702,19 +851,32 @@ class KubernetesScheduler(
              type_=str,
              help="The name of the PriorityClass to set on the job specs",
          )
+         opts.add(
+             "validate_spec",
+             type_=bool,
+             help="Validate job spec using Kubernetes API dry-run before submission",
+             default=True,
+         )
          return opts

      def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+         from kubernetes.client.rest import ApiException
+
          namespace, name = app_id.split(":")
          roles = {}
          roles_statuses = {}
-         resp = self._custom_objects_api().get_namespaced_custom_object_status(
-             group="batch.volcano.sh",
-             version="v1alpha1",
-             namespace=namespace,
-             plural="jobs",
-             name=name,
-         )
+         try:
+             resp = self._custom_objects_api().get_namespaced_custom_object_status(
+                 group="batch.volcano.sh",
+                 version="v1alpha1",
+                 namespace=namespace,
+                 plural="jobs",
+                 name=name,
+             )
+         except ApiException as e:
+             if e.status == 404:
+                 return None
+             raise
          status = resp.get("status")
          if status:
              state_str = status["state"]["phase"]
@@ -823,13 +985,34 @@ def create_scheduler(
  def pod_labels(
      app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
  ) -> Dict[str, str]:
+
+     def clean(label_value: str) -> str:
+         # cleans the provided `label_value` to make it compliant
+         # to pod label specs as described in
+         # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
+         #
+         # Valid label value:
+         # must be 63 characters or less (can be empty),
+         # unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
+         # could contain dashes (-), underscores (_), dots (.), and alphanumerics between.
+
+         # Replace invalid characters (allow: alphanum, -, _, .) with "."
+         label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
+         # Replace leading non-alphanumeric with "."
+         label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
+         # Replace trailing non-alphanumeric with "."
+         label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
+
+         # Trim to 63 characters
+         return label_value[:63]
+
      return {
-         LABEL_VERSION: torchx.__version__,
-         LABEL_APP_NAME: app.name,
+         LABEL_VERSION: clean(torchx.__version__),
+         LABEL_APP_NAME: clean(app.name),
          LABEL_ROLE_INDEX: str(role_idx),
-         LABEL_ROLE_NAME: role.name,
+         LABEL_ROLE_NAME: clean(role.name),
          LABEL_REPLICA_ID: str(replica_id),
-         LABEL_KUBE_APP_NAME: app.name,
+         LABEL_KUBE_APP_NAME: clean(app.name),
          LABEL_ORGANIZATION: "torchx.pytorch.org",
-         LABEL_UNIQUE_NAME: app_id,
+         LABEL_UNIQUE_NAME: clean(app_id),
      }
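
``clean`` above normalizes generated label values to Kubernetes label rules: at most 63 characters, starting and ending with an alphanumeric, with only ``-``, ``_`` and ``.`` in between. A hedged illustration of its effect on a few invented inputs, restating the helper locally so the snippet is self-contained:

.. code:: python

    import re

    def clean(label_value: str) -> str:
        # same regex steps as the helper added above
        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
        return label_value[:63]

    clean("2025.11.12+nightly")  # -> "2025.11.12.nightly" ("+" is not a valid label char)
    clean("my app/v1")           # -> "my.app.v1"           (space and "/" replaced)
    clean("trainer")             # -> "trainer"             (already compliant)
    clean("x" * 80)              # -> first 63 characters
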
torchx/schedulers/local_scheduler.py CHANGED
@@ -1159,6 +1159,7 @@ class LogIterator:
          self._check_finished()  # check to see if app has finished running

          if os.path.isfile(self._log_file):
+             time.sleep(0.1)  # fix timing issue
              self._log_fp = open(
                  self._log_file,
                  mode="rt",
torchx/schedulers/slurm_scheduler.py CHANGED
@@ -73,6 +73,15 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
      return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)


+ def get_appstate_from_job(job: dict[str, object]) -> AppState:
+     # Prior to slurm-23.11, job_state was a string and not a list
+     job_state = job.get("job_state", None)
+     if isinstance(job_state, list):
+         return appstate_from_slurm_state(job_state[0])
+     else:
+         return appstate_from_slurm_state(str(job_state))
+
+
  def version() -> Tuple[int, int]:
      """
      Uses ``sinfo --version`` to get the slurm version. If the command fails, it
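
``squeue --json`` changed the shape of ``job_state`` around slurm-23.11: older versions report a plain string, newer ones a list of states. A hedged sketch of the two payload shapes the new helper normalizes (field values invented):

.. code:: python

    from torchx.schedulers.slurm_scheduler import get_appstate_from_job

    old_style = {"job_id": 1234, "job_state": "RUNNING"}    # pre slurm-23.11
    new_style = {"job_id": 1234, "job_state": ["RUNNING"]}  # slurm-23.11 and later

    # Both resolve to the same AppState via appstate_from_slurm_state
    assert get_appstate_from_job(old_style) == get_appstate_from_job(new_style)
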
@@ -210,6 +219,7 @@ class SlurmReplicaRequest:
              sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
          else:
              sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+         sbatch_opts.setdefault("ntasks", "1")

          srun_opts = {
              "output": f"slurm-{macros.app_id}-{name}.out",
@@ -569,6 +579,8 @@ class SlurmScheduler(
          return self._describe_sacct(app_id)

      def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
+         # NOTE: Handles multiple job ID formats due to SLURM version differences.
+         # Different clusters use heterogeneous (+) vs regular (.) job ID formats.
          try:
              output = subprocess.check_output(
                  ["sacct", "--parsable2", "-j", app_id],
@@ -593,15 +605,27 @@ class SlurmScheduler(
          msg = ""
          app_state = AppState.UNKNOWN
          for row in reader:
-             job_id, *parts = row["JobID"].split("+")
+             # Handle both "+" (heterogeneous) and "." (regular) job ID formats
+             job_id_full = row["JobID"]
+
+             # Split on both "+" and "." to handle different SLURM configurations
+             if "+" in job_id_full:
+                 job_id, *parts = job_id_full.split("+")
+                 is_subjob = len(parts) > 0 and "." in parts[0]
+             else:
+                 job_id, *parts = job_id_full.split(".")
+                 is_subjob = len(parts) > 0
+
              if job_id != app_id:
                  continue
-             if len(parts) > 0 and "." in parts[0]:
-                 # we only care about the worker not the child jobs
+
+             if is_subjob:
+                 # we only care about the main job not the child jobs (.batch, .0, etc.)
                  continue

-             state = row["State"]
-             msg = state
+             msg = row["State"]
+             # Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
+             state = msg.split()[0].rstrip("+")
              app_state = appstate_from_slurm_state(state)

              role, _, replica_id = row["JobName"].rpartition("-")
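
``sacct --parsable2 -j <app_id>`` emits one row per job step, and the step suffix differs between SLURM setups: heterogeneous jobs use ``+`` components while regular jobs use ``.`` steps. A sketch of which rows the loop above keeps for ``app_id = "1234"`` (rows invented):

.. code:: python

    rows = [
        "1234",          # kept: main job row
        "1234.batch",    # skipped: "." step of a regular job
        "1234.0",        # skipped: "." step of a regular job
        "1234+0",        # kept: heterogeneous component without a "." step
        "1234+0.batch",  # skipped: step of a heterogeneous component
    ]
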
@@ -628,6 +652,9 @@ class SlurmScheduler(
          )

      def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
+         # NOTE: This method contains multiple compatibility checks for different SLURM versions
+         # due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).
+
          # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
          # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
          output = subprocess.check_output(
@@ -648,7 +675,7 @@

              entrypoint = job["command"]
              image = job["current_working_directory"]
-             state = appstate_from_slurm_state(job["job_state"][0])
+             state = get_appstate_from_job(job)

              job_resources = job["job_resources"]

@@ -669,7 +696,18 @@
              if state == AppState.PENDING:
                  # NOTE: torchx launched jobs points to exactly one host
                  # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
-                 hostname = job_resources.get("scheduled_nodes", "")
+
+                 # SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
+                 if job_resources is not None:
+                     hostname = job_resources.get("scheduled_nodes", "")
+                     # If scheduled_nodes not found in job_resources, try nodes.list
+                     if not hostname and "nodes" in job_resources:
+                         nodes_info = job_resources.get("nodes", {})
+                         if isinstance(nodes_info, dict):
+                             hostname = nodes_info.get("list", "")
+                 else:
+                     # For pending jobs where job_resources is None, check top-level fields
+                     hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")

                  role.num_replicas += 1
                  role_status.replicas.append(
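
On SLURM 24.11.5+, ``squeue --json`` can report ``"job_resources": null`` for jobs that are still pending (referenced as issue #1101 in the code comment), so the hostname has to come from top-level fields instead. A hedged sketch of the two pending-job shapes handled above (values invented):

.. code:: python

    pending_pre_24_11 = {
        "job_state": ["PENDING"],
        "job_resources": {"scheduled_nodes": "slurm-compute-node-0"},
    }

    pending_24_11_5 = {
        "job_state": ["PENDING"],
        "job_resources": None,            # newer SLURM omits resources until scheduling
        "nodes": "slurm-compute-node-0",  # fallback read from the top-level fields
    }
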
@@ -685,24 +723,35 @@
                  # where each replica is a "sub-job" so `allocated_nodes` will always be 1
                  # but we deal with jobs that have not been launched with torchx
                  # which can have multiple hosts per sub-job (count them as replicas)
-                 node_infos = job_resources.get("allocated_nodes", [])
+                 nodes_data = job_resources.get("nodes", {})
+
+                 # SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
+                 if "allocation" in nodes_data and isinstance(
+                     nodes_data["allocation"], list
+                 ):
+                     # SLURM 24.11+ format: nodes.allocation is a list
+                     for node_info in nodes_data["allocation"]:
+                         hostname = node_info["name"]
+                         cpu = int(node_info["cpus"]["used"])
+                         memMB = (
+                             int(node_info["memory"]["allocated"]) // 1024
+                         )  # Convert to MB

-                 if not isinstance(node_infos, list):
-                     # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
-                     # is not a list of individual nodes, but a map of the nodelist specs
-                     # in this case just use jobs[].job_resources.nodes
-                     hostname = job_resources.get("nodes")
-                     role.num_replicas += 1
-                     role_status.replicas.append(
-                         ReplicaStatus(
-                             id=int(replica_id),
-                             role=role_name,
-                             state=state,
-                             hostname=hostname,
+                         role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                         role.num_replicas += 1
+                         role_status.replicas.append(
+                             ReplicaStatus(
+                                 id=int(replica_id),
+                                 role=role_name,
+                                 state=state,
+                                 hostname=hostname,
+                             )
                          )
-                     )
-                 else:
-                     for node_info in node_infos:
+                 elif "allocated_nodes" in job_resources and isinstance(
+                     job_resources["allocated_nodes"], list
+                 ):
+                     # Legacy format: allocated_nodes is a list
+                     for node_info in job_resources["allocated_nodes"]:
                          # NOTE: we expect resource specs for all the nodes to be the same
                          # NOTE: use allocated (not used/requested) memory since
                          # users may only specify --cpu, in which case slurm
@@ -725,6 +774,26 @@
                                  hostname=hostname,
                              )
                          )
+                 else:
+                     # Fallback: use hostname from nodes.list
+                     if isinstance(nodes_data, str):
+                         hostname = nodes_data
+                     else:
+                         hostname = (
+                             nodes_data.get("list", "")
+                             if isinstance(nodes_data, dict)
+                             else ""
+                         )
+
+                     role.num_replicas += 1
+                     role_status.replicas.append(
+                         ReplicaStatus(
+                             id=int(replica_id),
+                             role=role_name,
+                             state=state,
+                             hostname=hostname,
+                         )
+                     )

          return DescribeAppResponse(
              app_id=app_id,
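
For running jobs, SLURM 24.11 also moved the per-node allocation data from a top-level ``allocated_nodes`` list to ``nodes.allocation``; when neither shape is present, the code above falls back to the ``nodes`` value (a plain string or a dict with a ``list`` entry) for the hostname. A hedged sketch of the two shapes (only the keys read by the new branch are shown; values invented):

.. code:: python

    legacy_job_resources = {
        # pre-24.11: a list of per-node dicts directly under "allocated_nodes"
        "allocated_nodes": [{"...": "per-node cpu/memory fields"}],
    }

    new_job_resources = {
        # SLURM 24.11+: the list lives under nodes.allocation with nested maps
        "nodes": {
            "allocation": [
                {"name": "node-1", "cpus": {"used": 8}, "memory": {"allocated": 65536}},
            ],
        },
    }
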
@@ -821,7 +890,7 @@ class SlurmScheduler(
              out.append(
                  ListAppResponse(
                      app_id=str(job["job_id"]),
-                     state=SLURM_STATES[job["job_state"][0]],
+                     state=get_appstate_from_job(job),
                      name=job["name"],
                  )
              )
torchx/specs/__init__.py CHANGED
@@ -1,4 +1,3 @@
- #!/usr/bin/env python3
  # Copyright (c) Meta Platforms, Inc. and affiliates.
  # All rights reserved.
  #
@@ -13,6 +12,8 @@ used by components to define the apps which can then be launched via a TorchX
  scheduler or pipeline adapter.
  """
  import difflib
+
+ import os
  from typing import Callable, Dict, Mapping, Optional

  from torchx.specs.api import (
@@ -42,9 +43,11 @@ from torchx.specs.api import (
      RoleStatus,
      runopt,
      runopts,
+     TORCHX_HOME,
      UnknownAppException,
      UnknownSchedulerException,
      VolumeMount,
+     Workspace,
  )
  from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts

@@ -52,14 +55,22 @@ from torchx.util.entrypoints import load_group

  from torchx.util.modules import import_attr

- AWS_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
+ GiB: int = 1024
+
+
+ ResourceFactory = Callable[[], Resource]
+
+ AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
      "torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={}
  )
- GENERIC_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
+ GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
      "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
  )
-
- GiB: int = 1024
+ CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+     os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
+     "NAMED_RESOURCES",
+     default={},
+ )


  def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
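
``CUSTOM_NAMED_RESOURCES`` lets a deployment register additional named resources by pointing the ``TORCHX_CUSTOM_NAMED_RESOURCES`` environment variable at a module that exposes a ``NAMED_RESOURCES`` mapping (the default path ``torchx.specs.fb.named_resources`` is an internal module). A hedged sketch of what such a module could look like; the module name and resource values are made up:

.. code:: python

    # my_company/torchx_resources.py  (hypothetical module)
    from typing import Callable, Mapping

    from torchx.specs.api import Resource


    def _gpu_large() -> Resource:
        return Resource(cpu=64, gpu=8, memMB=512 * 1024)


    NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
        "gpu.large": _gpu_large,
    }

    # then, before running torchx:
    #   export TORCHX_CUSTOM_NAMED_RESOURCES=my_company.torchx_resources
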
@@ -69,6 +80,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
      for name, resource in {
          **GENERIC_NAMED_RESOURCES,
          **AWS_NAMED_RESOURCES,
+         **CUSTOM_NAMED_RESOURCES,
          **resource_methods,
      }.items():
          materialized_resources[name] = resource
@@ -122,7 +134,7 @@ def resource(

      If ``h`` is specified then it is used to look up the
      resource specs from the list of registered named resources.
-     See `registering named resource <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+     See `registering named resource <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

      Otherwise a ``Resource`` object is created from the raw resource specs.

@@ -225,5 +237,10 @@ __all__ = [
      "make_app_handle",
      "materialize_appdef",
      "parse_mounts",
+     "torchx_run_args_from_argparse",
+     "torchx_run_args_from_json",
+     "TorchXRunArgs",
      "ALL",
+     "TORCHX_HOME",
+     "Workspace",
  ]