torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly might be problematic.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/schedulers/slurm_scheduler.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 This contains the TorchX Slurm scheduler which can be used to run TorchX
 components on a Slurm cluster.
@@ -16,13 +18,14 @@ import os.path
 import shlex
 import subprocess
 import tempfile
+import warnings
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from subprocess import CalledProcessError, PIPE
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
 
 import torchx
 from torchx.schedulers.api import (
-    AppDryRunInfo,
     DescribeAppResponse,
     filter_regex,
     ListAppResponse,
@@ -33,16 +36,17 @@ from torchx.schedulers.api import (
 from torchx.schedulers.local_scheduler import LogIterator
 from torchx.specs import (
     AppDef,
+    AppDryRunInfo,
     AppState,
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
 )
 from torchx.workspace.dir_workspace import DirWorkspaceMixin
-from typing_extensions import TypedDict
 
 SLURM_JOB_DIRS = ".torchxslurmjobdirs"
 
@@ -64,15 +68,80 @@ SLURM_STATES: Mapping[str, AppState] = {
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
+def get_appstate_from_job(job: dict[str, object]) -> AppState:
+    # Prior to slurm-23.11, job_state was a string and not a list
+    job_state = job.get("job_state", None)
+    if isinstance(job_state, list):
+        return appstate_from_slurm_state(job_state[0])
+    else:
+        return appstate_from_slurm_state(str(job_state))
+
+
+def version() -> Tuple[int, int]:
+    """
+    Uses ``sinfo --version`` to get the slurm version. If the command fails, it
+    assumes the version is ``slurm 24.05.8``.
+
+    Returns:
+    -------
+    Tuple[int, int] slurm version as a tuple of ints (major, minor).
+    """
+
+    cmd = ["sinfo", "--version"]
+    try:
+        out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
+    except (CalledProcessError, FileNotFoundError):
+        out = "slurm 24.05.8"
+        warnings.warn(
+            "Error running: `{sinfo_cmd}` to get SLURM version. Are you running outside the "
+            "cluster's login or head node? This typically happens when running in `--dryrun`"
+            " mode. Assuming version is `slurm 24.05.8`.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    # sinfo --version returns in the form "slurm 24.1.0"
+    _, version_literal = out.split(" ", maxsplit=2)
+    major, minor = [int(v) for v in version_literal.split(".")][:2]
+
+    return (major, minor)
+
+
+def _should_use_gpus_per_node_from_version() -> bool:
+    """
+    Determine whether to use gpus-per-node based on automatically detected slurm version.
+
+    Change Reference: https://fburl.com/sqwqzxn6
+    > select/linear - Reject jobs asking for GRES per job|socket|task or cpus|mem per GRES.
+
+    Returns:
+        ``True`` in slurm ``version>=24.11.0``, ``False`` otherwise.
+    """
+
+    slurm_24_11_0 = (24, 11)
+    slurm_version = version()
+
+    return slurm_version[0] > slurm_24_11_0[0] or (  # Major version is greater
+        slurm_version[0] == slurm_24_11_0[0] and slurm_version[1] >= slurm_24_11_0[1]
+    )  # Major version is equal and minor version is greater or equal
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
     "mail-type",
+    "account",
 }
 SBATCH_GROUP_OPTIONS = {
     "partition",
     "time",
     "constraint",
+    "qos",
 }
 
 log: logging.Logger = logging.getLogger(__name__)
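
For context (not part of the diff), a minimal sketch of how the parsing in version() and the (24, 11) gate in _should_use_gpus_per_node_from_version() interact; the version strings below are hypothetical:

    # Mirrors the parsing and comparison logic added above.
    for raw in ("slurm 24.05.8", "slurm 24.11.0", "slurm 25.05.1"):
        _, version_literal = raw.split(" ", maxsplit=2)
        major, minor = [int(v) for v in version_literal.split(".")][:2]
        use_gpus_per_node = (major, minor) >= (24, 11)
        print(raw, "->", (major, minor),
              "gpus-per-node" if use_gpus_per_node else "gpus-per-task")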
@@ -91,6 +160,7 @@ def _apply_app_id_env(s: str) -> str:
 SlurmOpts = TypedDict(
     "SlurmOpts",
     {
+        "account": Optional[str],
         "partition": str,
         "time": str,
         "comment": Optional[str],
@@ -98,6 +168,7 @@ SlurmOpts = TypedDict(
         "mail-user": Optional[str],
         "mail-type": Optional[str],
         "job_dir": Optional[str],
+        "qos": Optional[str],
     },
     total=False,
 )
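
For context (not part of the diff), a cfg literal using the new keys; since the TypedDict is declared with total=False, every key stays optional, and the values here are placeholders:

    cfg: SlurmOpts = {
        "partition": "compute",
        "time": "1:00:00",
        "account": "my-account",  # placeholder account name
        "qos": "high",            # placeholder QoS level
    }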
@@ -118,7 +189,11 @@ class SlurmReplicaRequest:
 
     @classmethod
     def from_role(
-        cls, name: str, role: Role, cfg: SlurmOpts, nomem: bool
+        cls,
+        name: str,
+        role: Role,
+        cfg: SlurmOpts,
+        nomem: bool,
     ) -> "SlurmReplicaRequest":
         """
         ``from_role`` creates a SlurmReplicaRequest for the specific role and
@@ -141,7 +216,12 @@ class SlurmReplicaRequest:
         if not nomem and resource.memMB > 0:
             sbatch_opts.setdefault("mem", str(resource.memMB))
         if resource.gpu > 0:
-            sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+            # Use smart GPU allocation based on automatically detected Slurm version
+            if _should_use_gpus_per_node_from_version():
+                sbatch_opts.setdefault("gpus-per-node", str(resource.gpu))
+            else:
+                sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
+                sbatch_opts.setdefault("ntasks", "1")
 
         srun_opts = {
             "output": f"slurm-{macros.app_id}-{name}.out",
@@ -326,6 +406,12 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
 
     def _run_opts(self) -> runopts:
         opts = runopts()
+        opts.add(
+            "account",
+            type_=str,
+            help="The account to use for the slurm job.",
+            default=None,
+        )
         opts.add(
             "partition",
             type_=str,
@@ -368,6 +454,11 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
             iteration, jobs will be tracked in ``.torchxslurmjobdirs``.
             """,
         )
+        opts.add(
+            "qos",
+            type_=str,
+            help="Quality of Service (QoS) to assign to the job.",
+        )
         return opts
 
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
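
For context (not part of the diff): with these runopts registered, the new settings should be accepted like any other scheduler cfg value. Assuming the torchx CLI's documented -cfg/--scheduler_args syntax, with placeholder values:

    torchx run --scheduler slurm \
        --scheduler_args partition=compute,account=my-account,qos=high \
        utils.echo --msg hello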
@@ -470,7 +561,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
 
         return AppDryRunInfo(req, repr)
 
-    def _validate(self, app: AppDef, scheduler: str) -> None:
+    def _validate(self, app: AppDef, scheduler: str, cfg: SlurmOpts) -> None:
         # Skip validation step for slurm
         pass
 
@@ -478,10 +569,38 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
         subprocess.run(["scancel", app_id], check=True)
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        # 1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # 2. -- or -- squeue returns an empty jobs list
+        # in either case, fall back to the less descriptive but more persistent sacct
+        # (slurm cluster must have accounting storage enabled for sacct to work)
+        try:
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)
+
+    def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: Handles multiple job ID formats due to SLURM version differences.
+        # Different clusters use heterogeneous (+) vs regular (.) job ID formats.
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
 
@@ -492,20 +611,28 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
         msg = ""
         app_state = AppState.UNKNOWN
         for row in reader:
-            job_id, *parts = row["JobID"].split("+")
+            # Handle both "+" (heterogeneous) and "." (regular) job ID formats
+            job_id_full = row["JobID"]
+
+            # Split on both "+" and "." to handle different SLURM configurations
+            if "+" in job_id_full:
+                job_id, *parts = job_id_full.split("+")
+                is_subjob = len(parts) > 0 and "." in parts[0]
+            else:
+                job_id, *parts = job_id_full.split(".")
+                is_subjob = len(parts) > 0
+
             if job_id != app_id:
                 continue
-            if len(parts) > 0 and "." in parts[0]:
-                # we only care about the worker not the child jobs
+
+            if is_subjob:
+                # we only care about the main job not the child jobs (.batch, .0, etc.)
                 continue
 
-            state = row["State"]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            msg = row["State"]
+            # Remove truncation indicator (CANCELLED+) and extract base state from verbose formats
+            state = msg.split()[0].rstrip("+")
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
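
For context (not part of the diff), how the split above classifies a few hypothetical sacct JobID values:

    for job_id_full in ("54321", "54321+0", "54321+0.batch", "54321.batch", "54321.0"):
        if "+" in job_id_full:
            job_id, *parts = job_id_full.split("+")
            is_subjob = len(parts) > 0 and "." in parts[0]
        else:
            job_id, *parts = job_id_full.split(".")
            is_subjob = len(parts) > 0
        print(job_id_full, "->", job_id, "subjob" if is_subjob else "main")
    # 54321 -> main; 54321+0 -> main (het component row);
    # 54321+0.batch, 54321.batch, 54321.0 -> subjob (skipped)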
@@ -530,6 +657,157 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
             msg=msg,
         )
 
+    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: This method contains multiple compatibility checks for different SLURM versions
+        # due to API format changes across versions (20.02, 23.02, 24.05, 24.11+).
+
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
+        )
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+        if not jobs:
+            return None
+
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = get_appstate_from_job(job)
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
+                ),
+            )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+
+                # SLURM 24.11.5+ returns job_resources=None for pending jobs (issue #1101)
+                if job_resources is not None:
+                    hostname = job_resources.get("scheduled_nodes", "")
+                    # If scheduled_nodes not found in job_resources, try nodes.list
+                    if not hostname and "nodes" in job_resources:
+                        nodes_info = job_resources.get("nodes", {})
+                        if isinstance(nodes_info, dict):
+                            hostname = nodes_info.get("list", "")
+                else:
+                    # For pending jobs where job_resources is None, check top-level fields
+                    hostname = job.get("nodes", "") or job.get("scheduled_nodes", "")
+
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                nodes_data = job_resources.get("nodes", {})
+
+                # SLURM 24.11+ changed from allocated_nodes to nodes.allocation structure
+                if "allocation" in nodes_data and isinstance(
+                    nodes_data["allocation"], list
+                ):
+                    # SLURM 24.11+ format: nodes.allocation is a list
+                    for node_info in nodes_data["allocation"]:
+                        hostname = node_info["name"]
+                        cpu = int(node_info["cpus"]["used"])
+                        memMB = (
+                            int(node_info["memory"]["allocated"]) // 1024
+                        )  # Convert to MB
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
+                elif "allocated_nodes" in job_resources and isinstance(
+                    job_resources["allocated_nodes"], list
+                ):
+                    # Legacy format: allocated_nodes is a list
+                    for node_info in job_resources["allocated_nodes"]:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                        # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
+                else:
+                    # Fallback: use hostname from nodes.list
+                    if isinstance(nodes_data, str):
+                        hostname = nodes_data
+                    else:
+                        hostname = (
+                            nodes_data.get("list", "")
+                            if isinstance(nodes_data, dict)
+                            else ""
+                        )
+
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+
+        return DescribeAppResponse(
+            app_id=app_id,
+            roles=list(roles.values()),
+            roles_statuses=list(roles_statuses.values()),
+            state=state,
+        )
+
     def log_iter(
         self,
         app_id: str,
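
For context (not part of the diff), hypothetical job_resources fragments for the three shapes handled above; the field names come from the code, the values are made up:

    # SLURM 24.11+: nodes.allocation is a list
    jr_new = {"nodes": {"allocation": [
        {"name": "node001", "cpus": {"used": 8},
         "memory": {"allocated": 8 * 1024 * 1024}},  # // 1024 -> 8192 MB
    ]}}
    # legacy: allocated_nodes is a list (memory already in MB)
    jr_legacy = {"allocated_nodes": [
        {"nodename": "node001", "cpus_used": 8, "memory_allocated": 8192},
    ]}
    # fallback: only a node-list expression is available
    jr_fallback = {"nodes": {"list": "node001"}}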
@@ -570,6 +848,12 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
         return iterator
 
     def list(self) -> List[ListAppResponse]:
+        try:
+            return self._list_sacct()
+        except subprocess.CalledProcessError:
+            return self._list_squeue()
+
+    def _list_sacct(self) -> List[ListAppResponse]:
         # By default sacct only returns accounting information of jobs launched on the current day
         # To return all jobs launched, set starttime to one second past unix epoch time
         # Starttime will be modified when listing jobs by timeframe is supported
@@ -586,6 +870,38 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
             for job in output_json["jobs"]
         ]
 
+    def _list_squeue(self) -> List[ListAppResponse]:
+        # if sacct isn't configured on the cluster, fallback to squeue which
+        # only has currently running jobs
+        p = subprocess.run(
+            ["squeue", "--json"],
+            stdout=subprocess.PIPE,
+            check=True,
+        )
+        output_json = json.loads(p.stdout.decode("utf-8"))
+
+        out = []
+        for job in output_json["jobs"]:
+            job_id = job["job_id"]
+
+            het_job_id = job.get("het_job_id")
+            if (
+                het_job_id
+                and het_job_id["set"]
+                and het_job_id["number"] != job_id
+                and het_job_id["number"] > 0
+            ):
+                continue
+
+            out.append(
+                ListAppResponse(
+                    app_id=str(job["job_id"]),
+                    state=get_appstate_from_job(job),
+                    name=job["name"],
+                )
+            )
+        return out
+
 
 def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
     return SlurmScheduler(
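
For context (not part of the diff), a hypothetical squeue --json payload showing which rows the het-job filter above keeps:

    jobs = [
        # het-job leader: het_job_id.number == job_id, so it is listed
        {"job_id": 100, "name": "trainer-0", "job_state": ["RUNNING"],
         "het_job_id": {"set": True, "number": 100}},
        # component of the same het job: number != job_id, so it is skipped
        {"job_id": 101, "name": "trainer-1", "job_state": ["RUNNING"],
         "het_job_id": {"set": True, "number": 100}},
        # regular (non-het) job: set is False, so it is listed
        {"job_id": 200, "name": "solo-0", "job_state": ["RUNNING"],
         "het_job_id": {"set": False, "number": 0}},
    ]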
torchx/schedulers/streams.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import io
 import os
 import threading
torchx/specs/__init__.py CHANGED
@@ -1,25 +1,22 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 This contains the TorchX AppDef and related component definitions. These are
 used by components to define the apps which can then be launched via a TorchX
 scheduler or pipeline adapter.
 """
 import difflib
-from typing import Callable, Dict, Optional
 
-from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES
-from torchx.specs.named_resources_generic import (
-    NAMED_RESOURCES as GENERIC_NAMED_RESOURCES,
-)
-from torchx.util.entrypoints import load_group
+import os
+from typing import Callable, Dict, Iterator, Mapping, Optional
 
-from .api import (  # noqa: F401 F403
+from torchx.specs.api import (
     ALL,
     AppDef,
     AppDryRunInfo,
@@ -46,15 +43,36 @@ from .api import (  # noqa: F401 F403
     RoleStatus,
     runopt,
     runopts,
+    TORCHX_HOME,
     UnknownAppException,
     UnknownSchedulerException,
     VolumeMount,
+    Workspace,
 )
-from .builders import make_app_handle, materialize_appdef, parse_mounts  # noqa
+from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts
+
+from torchx.util.entrypoints import load_group
+
+from torchx.util.modules import import_attr
 
 GiB: int = 1024
 
 
+ResourceFactory = Callable[[], Resource]
+
+AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+    "torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={}
+)
+GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+    "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
+)
+CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+    os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
+    "NAMED_RESOURCES",
+    default={},
+)
+
+
 def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
     resource_methods = load_group("torchx.named_resources", default={})
     materialized_resources: Dict[str, Callable[[], Resource]] = {}
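
For context (not part of the diff): the CUSTOM_NAMED_RESOURCES hook lets a deployment register extra resource shapes by pointing TORCHX_CUSTOM_NAMED_RESOURCES at any module exposing a NAMED_RESOURCES mapping. A hypothetical my_named_resources.py (module, names, and sizes all made up):

    from torchx.specs.api import Resource

    def gpu_x8() -> Resource:
        # 96 vCPUs, 8 GPUs, 1 TiB host memory (illustrative numbers)
        return Resource(cpu=96, gpu=8, memMB=1024 * 1024)

    NAMED_RESOURCES = {"gpu.x8": gpu_x8}

With TORCHX_CUSTOM_NAMED_RESOURCES=my_named_resources set, these entries get merged into the registry by _load_named_resources, shown in the next hunk.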
@@ -62,6 +80,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
     for name, resource in {
         **GENERIC_NAMED_RESOURCES,
         **AWS_NAMED_RESOURCES,
+        **CUSTOM_NAMED_RESOURCES,
         **resource_methods,
     }.items():
         materialized_resources[name] = resource
@@ -94,8 +113,22 @@ class _NamedResourcesLibrary:
     def __contains__(self, key: str) -> bool:
         return key in _named_resource_factories
 
-    def __iter__(self) -> None:
-        raise NotImplementedError("named resources doesn't support iterating")
+    def __iter__(self) -> Iterator[str]:
+        """Iterates through the names of the registered named_resources.
+
+        Usage:
+
+        .. doctest::
+
+            from torchx import specs
+
+            for resource_name in specs.named_resources:
+                resource = specs.resource(h=resource_name)
+                assert isinstance(resource, specs.Resource)
+
+        """
+        for key in _named_resource_factories:
+            yield (key)
 
 
 named_resources: _NamedResourcesLibrary = _NamedResourcesLibrary()
@@ -115,7 +148,7 @@ def resource(
 
     If ``h`` is specified then it is used to look up the
     resource specs from the list of registered named resources.
-    See `registering named resource <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resource <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
 
     Otherwise a ``Resource`` object is created from the raw resource specs.
 
@@ -181,3 +214,47 @@ def get_named_resources(res: str) -> Resource:
 
     """
     return named_resources[res]
+
+
+__all__ = [
+    "AppDef",
+    "AppDryRunInfo",
+    "AppHandle",
+    "AppState",
+    "AppStatus",
+    "BindMount",
+    "CfgVal",
+    "DeviceMount",
+    "get_type_name",
+    "is_terminal",
+    "macros",
+    "MISSING",
+    "NONE",
+    "NULL_RESOURCE",
+    "parse_app_handle",
+    "ReplicaState",
+    "ReplicaStatus",
+    "Resource",
+    "RetryPolicy",
+    "Role",
+    "RoleStatus",
+    "runopt",
+    "runopts",
+    "UnknownAppException",
+    "UnknownSchedulerException",
+    "InvalidRunConfigException",
+    "MalformedAppHandleException",
+    "VolumeMount",
+    "resource",
+    "get_named_resources",
+    "named_resources",
+    "make_app_handle",
+    "materialize_appdef",
+    "parse_mounts",
+    "torchx_run_args_from_argparse",
+    "torchx_run_args_from_json",
+    "TorchXRunArgs",
+    "ALL",
+    "TORCHX_HOME",
+    "Workspace",
+]
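
For context (not part of the diff), a usage sketch of the exported surface; the named-resource handle is illustrative, and GiB is the constant defined earlier in this file:

    from torchx import specs

    r1 = specs.resource(h="aws_p3.2xlarge")                 # resolve a registered named resource
    r2 = specs.resource(cpu=4, gpu=0, memMB=8 * specs.GiB)  # build from raw specs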