torchx-nightly 2025.9.28__py3-none-any.whl → 2025.11.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly has been flagged and may be problematic.

@@ -27,10 +27,81 @@ Install Volcano:
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
+the ``kubernetes`` metadata on your role. The value can be:
+
+- A dict with the overlay structure
+- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
+
+Merge semantics:
+
+- **dict**: recursive merge (upsert)
+- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
+- **primitives**: replace
+
+.. code:: python
+
+    from torchx.specs import Role
+
+    # Dict overlay - lists append, tuples replace
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": {
+                "spec": {
+                    "nodeSelector": {"gpu": "true"},
+                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],  # appends
+                    "volumes": ({"name": "my-volume", "emptyDir": {}},)  # replaces
+                }
+            }
+        }
+    )
+
+    # File URI overlay
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": "file:///path/to/pod_overlay.yaml"
+        }
+    )
+
+CLI usage with builtin components:
+
+.. code:: bash
+
+    $ torchx run --scheduler kubernetes dist.ddp \\
+        --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
+        --script train.py
+
+Example ``pod_overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+      - key: nvidia.com/gpu
+        operator: Exists
+        effect: NoSchedule
+      volumes: !!python/tuple
+      - name: my-volume
+        emptyDir: {}
+
+The overlay is deep-merged with the generated pod, preserving existing fields
+and adding or overriding specified ones.
 """
 
 import json
 import logging
+import re
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
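
Why the ``!!python/tuple`` tag matters: overlay files are loaded with PyYAML's
unsafe loader (see ``yaml.unsafe_load`` in a later hunk), which turns that tag
into a real Python tuple, and under the merge rules above tuples mean "replace"
rather than "append". A minimal sketch, assuming PyYAML is installed; the
overlay content is illustrative:

.. code:: python

    import textwrap

    import yaml

    overlay_src = textwrap.dedent(
        """\
        spec:
          tolerations:              # plain YAML list -> appended
          - key: nvidia.com/gpu
            operator: Exists
          volumes: !!python/tuple   # tuple tag -> replaces wholesale
          - name: my-volume
            emptyDir: {}
        """
    )

    overlay = yaml.unsafe_load(overlay_src)
    assert isinstance(overlay["spec"]["tolerations"], list)  # will append
    assert isinstance(overlay["spec"]["volumes"], tuple)     # will replace
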
@@ -45,6 +116,7 @@ from typing import (
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )
 
 import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024
 
+
+def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+    """Apply overlay dict to V1Pod object, merging nested fields.
+
+    Merge semantics:
+    - dict: upsert (recursive merge)
+    - list: append by default, replace if tuple
+    - primitives: replace
+    """
+    from kubernetes import client
+
+    api = client.ApiClient()
+    pod_dict = api.sanitize_for_serialization(pod)
+
+    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+        for key, value in overlay.items():
+            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                deep_merge(base[key], value)
+            elif isinstance(value, tuple):
+                base[key] = list(value)
+            elif (
+                isinstance(value, list) and key in base and isinstance(base[key], list)
+            ):
+                base[key].extend(value)
+            else:
+                base[key] = value
+
+    deep_merge(pod_dict, overlay)
+
+    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+    pod.spec = merged_pod.spec
+    pod.metadata = merged_pod.metadata
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
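
The helper's merge rules are easiest to see on plain dicts, outside of any
Kubernetes objects. A self-contained sketch of the same ``deep_merge`` logic
(the sample pod dict is invented for illustration):

.. code:: python

    from typing import Any, Dict

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        # Same rules as _apply_pod_overlay above: dicts merge recursively,
        # tuples replace, lists append, primitives overwrite.
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)
            elif isinstance(value, tuple):
                base[key] = list(value)
            elif isinstance(value, list) and key in base and isinstance(base[key], list):
                base[key].extend(value)
            else:
                base[key] = value

    pod = {"spec": {"tolerations": [{"key": "a"}], "volumes": [{"name": "old"}]}}
    deep_merge(
        pod,
        {"spec": {"tolerations": [{"key": "b"}], "volumes": ({"name": "new"},)}},
    )
    assert pod["spec"]["tolerations"] == [{"key": "a"}, {"key": "b"}]  # appended
    assert pod["spec"]["volumes"] == [{"name": "new"}]                 # replaced
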
@@ -369,7 +475,7 @@ def app_to_resource(
     queue: str,
     service_account: Optional[str],
     priority_class: Optional[str] = None,
-) -> Dict[str, object]:
+) -> Dict[str, Any]:
     """
     app_to_resource creates a volcano job kubernetes resource definition from
     the provided AppDef. The resource definition can be used to launch the
@@ -399,8 +505,20 @@
             replica_role = values.apply(role)
             if role_idx == 0 and replica_id == 0:
                 replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+                replica_role.env["TORCHX_IMAGE"] = replica_role.image
 
             pod = role_to_pod(name, replica_role, service_account)
+            if k8s_metadata := role.metadata.get("kubernetes"):
+                if isinstance(k8s_metadata, str):
+                    import fsspec
+
+                    with fsspec.open(k8s_metadata, "r") as f:
+                        k8s_metadata = yaml.unsafe_load(f)
+                elif not isinstance(k8s_metadata, dict):
+                    raise ValueError(
+                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
+                    )
+                _apply_pod_overlay(pod, k8s_metadata)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
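
For the string branch, any fsspec-resolvable URI should work. A quick local
round-trip sketch (temp file path, not a real workflow; assumes fsspec and
PyYAML are installed):

.. code:: python

    import tempfile

    import fsspec
    import yaml

    # Write a small overlay to disk, then load it the way the scheduler does.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
        yaml.safe_dump({"spec": {"nodeSelector": {"gpu": "true"}}}, tmp)

    with fsspec.open(f"file://{tmp.name}", "r") as f:
        overlay = yaml.unsafe_load(f)

    assert overlay == {"spec": {"nodeSelector": {"gpu": "true"}}}
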
@@ -443,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
     if priority_class is not None:
         job_spec["priorityClassName"] = priority_class
 
-    resource: Dict[str, object] = {
+    resource: Dict[str, Any] = {
         "apiVersion": "batch.volcano.sh/v1alpha1",
         "kind": "Job",
         "metadata": {"name": f"{unique_app_id}"},
@@ -455,7 +573,7 @@
 @dataclass
 class KubernetesJob:
     images_to_push: Dict[str, Tuple[str, str]]
-    resource: Dict[str, object]
+    resource: Dict[str, Any]
 
     def __str__(self) -> str:
         return yaml.dump(sanitize_for_serialization(self.resource))
@@ -470,6 +588,7 @@ class KubernetesOpts(TypedDict, total=False):
     image_repo: Optional[str]
     service_account: Optional[str]
     priority_class: Optional[str]
+    validate_spec: Optional[bool]
 
 
 class KubernetesScheduler(
@@ -485,7 +604,7 @@ class KubernetesScheduler(
     For installation instructions see: https://github.com/volcano-sh/volcano
 
     This has been confirmed to work with Volcano v1.3.0 and Kubernetes versions
-    v1.18-1.21. See https://github.com/pytorch/torchx/issues/120 which is
+    v1.18-1.21. See https://github.com/meta-pytorch/torchx/issues/120 which is
     tracking Volcano support for Kubernetes v1.22.
 
     .. note::
@@ -635,7 +754,7 @@ class KubernetesScheduler(
         else:
             raise
 
-        return f'{namespace}:{resp["metadata"]["name"]}'
+        return f"{namespace}:{resp['metadata']['name']}"
 
     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts
@@ -658,6 +777,36 @@ class KubernetesScheduler(
         ), "priority_class must be a str"
 
         resource = app_to_resource(app, queue, service_account, priority_class)
+
+        if cfg.get("validate_spec"):
+            try:
+                self._custom_objects_api().create_namespaced_custom_object(
+                    group="batch.volcano.sh",
+                    version="v1alpha1",
+                    namespace=cfg.get("namespace") or "default",
+                    plural="jobs",
+                    body=resource,
+                    dry_run="All",
+                )
+            except Exception as e:
+                from kubernetes.client.rest import ApiException
+
+                if isinstance(e, ApiException):
+                    raise ValueError(f"Invalid job spec: {e.reason}") from e
+                raise
+
+        job_name = resource["metadata"]["name"]
+        for task in resource["spec"]["tasks"]:
+            task_name = task["name"]
+            replicas = task.get("replicas", 1)
+            max_index = replicas - 1
+            pod_name = f"{job_name}-{task_name}-{max_index}"
+            if len(pod_name) > 63:
+                raise ValueError(
+                    f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
+                    f"Shorten app.name or role names"
+                )
+
         req = KubernetesJob(
             resource=resource,
             images_to_push=images_to_push,
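
The 63-character cap is the Kubernetes DNS-1123 label limit on pod names, and
the check uses the highest replica index since that is the longest name the job
will ever produce. A sketch of the same arithmetic for sizing app and role
names up front (sample names invented):

.. code:: python

    def worst_case_pod_name(job_name: str, task_name: str, replicas: int) -> str:
        # Mirrors the formula in the dryrun check: the highest-index
        # replica has the longest pod name.
        return f"{job_name}-{task_name}-{replicas - 1}"

    name = worst_case_pod_name("trainer-abc123", "worker", replicas=128)
    assert len(name) <= 63, f"{name!r} would be rejected at submit time"

The server-side dry-run itself can presumably be skipped by setting the new
``validate_spec`` run option to ``False``, at the cost of only catching schema
errors at actual submission.
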
@@ -702,19 +851,32 @@ class KubernetesScheduler(
             type_=str,
             help="The name of the PriorityClass to set on the job specs",
         )
+        opts.add(
+            "validate_spec",
+            type_=bool,
+            help="Validate job spec using Kubernetes API dry-run before submission",
+            default=True,
+        )
         return opts
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        from kubernetes.client.rest import ApiException
+
         namespace, name = app_id.split(":")
         roles = {}
         roles_statuses = {}
-        resp = self._custom_objects_api().get_namespaced_custom_object_status(
-            group="batch.volcano.sh",
-            version="v1alpha1",
-            namespace=namespace,
-            plural="jobs",
-            name=name,
-        )
+        try:
+            resp = self._custom_objects_api().get_namespaced_custom_object_status(
+                group="batch.volcano.sh",
+                version="v1alpha1",
+                namespace=namespace,
+                plural="jobs",
+                name=name,
+            )
+        except ApiException as e:
+            if e.status == 404:
+                return None
+            raise
        status = resp.get("status")
        if status:
            state_str = status["state"]["phase"]
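
With this change, ``describe`` on a deleted (or never-created) job returns
``None`` instead of surfacing an ``ApiException``; callers can branch on that.
A sketch, with a hypothetical app id and an already-constructed scheduler:

.. code:: python

    # `scheduler` is a KubernetesScheduler instance; the app id format is
    # "<namespace>:<job name>" as returned by submit.
    resp = scheduler.describe("default:my-job")
    if resp is None:
        print("job not found; it may have been deleted")
    else:
        print(resp.state)
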
@@ -823,13 +985,34 @@ def create_scheduler(
 def pod_labels(
     app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
 ) -> Dict[str, str]:
+
+    def clean(label_value: str) -> str:
+        # cleans the provided `label_value` to make it compliant
+        # to pod label specs as described in
+        # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
+        #
+        # Valid label value:
+        #  must be 63 characters or less (can be empty),
+        #  unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
+        #  could contain dashes (-), underscores (_), dots (.), and alphanumerics between.
+
+        # Replace invalid characters (allow: alphanum, -, _, .) with "."
+        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
+        # Replace leading non-alphanumeric with "."
+        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
+        # Replace trailing non-alphanumeric with "."
+        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
+
+        # Trim to 63 characters
+        return label_value[:63]
+
     return {
-        LABEL_VERSION: torchx.__version__,
-        LABEL_APP_NAME: app.name,
+        LABEL_VERSION: clean(torchx.__version__),
+        LABEL_APP_NAME: clean(app.name),
         LABEL_ROLE_INDEX: str(role_idx),
-        LABEL_ROLE_NAME: role.name,
+        LABEL_ROLE_NAME: clean(role.name),
         LABEL_REPLICA_ID: str(replica_id),
-        LABEL_KUBE_APP_NAME: app.name,
+        LABEL_KUBE_APP_NAME: clean(app.name),
         LABEL_ORGANIZATION: "torchx.pytorch.org",
-        LABEL_UNIQUE_NAME: app_id,
+        LABEL_UNIQUE_NAME: clean(app_id),
     }
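
A few worked inputs for the new ``clean`` helper (inputs chosen for
illustration; the regexes are exactly the ones above):

.. code:: python

    import re

    def clean(label_value: str) -> str:
        # Same three passes as pod_labels.clean above.
        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
        return label_value[:63]

    assert clean("my app/v1.2") == "my.app.v1.2"  # invalid chars -> "."
    assert clean("a" * 100) == "a" * 63           # trimmed to 63 chars
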
@@ -73,6 +73,15 @@ def appstate_from_slurm_state(slurm_state: str) -> AppState:
     return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
 
 
+def get_appstate_from_job(job: dict[str, object]) -> AppState:
+    # Prior to slurm-23.11, job_state was a string and not a list
+    job_state = job.get("job_state", None)
+    if isinstance(job_state, list):
+        return appstate_from_slurm_state(job_state[0])
+    else:
+        return appstate_from_slurm_state(str(job_state))
+
+
 def version() -> Tuple[int, int]:
     """
     Uses ``sinfo --version`` to get the slurm version. If the command fails, it
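
The two payload shapes this normalizes: slurm >= 23.11 returns ``job_state``
as a list, older versions as a plain string. A sketch (sample payloads
invented; assumes the function is imported from the module in this diff):

.. code:: python

    new_style = {"job_state": ["RUNNING"]}  # slurm >= 23.11
    old_style = {"job_state": "RUNNING"}    # pre-23.11

    assert get_appstate_from_job(new_style) == get_appstate_from_job(old_style)
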
@@ -666,7 +675,7 @@ class SlurmScheduler(
 
         entrypoint = job["command"]
         image = job["current_working_directory"]
-        state = appstate_from_slurm_state(job["job_state"][0])
+        state = get_appstate_from_job(job)
 
         job_resources = job["job_resources"]
 
@@ -881,7 +890,7 @@ class SlurmScheduler(
             out.append(
                 ListAppResponse(
                     app_id=str(job["job_id"]),
-                    state=SLURM_STATES[job["job_state"][0]],
+                    state=get_appstate_from_job(job),
                     name=job["name"],
                 )
             )
torchx/specs/__init__.py CHANGED
@@ -12,7 +12,9 @@ used by components to define the apps which can then be launched via a TorchX
 scheduler or pipeline adapter.
 """
 import difflib
-from typing import Callable, Dict, Mapping, Optional
+
+import os
+from typing import Callable, Dict, Iterator, Mapping, Optional
 
 from torchx.specs.api import (
     ALL,
@@ -41,9 +43,11 @@ from torchx.specs.api import (
     RoleStatus,
     runopt,
     runopts,
+    TORCHX_HOME,
     UnknownAppException,
     UnknownSchedulerException,
     VolumeMount,
+    Workspace,
 )
 from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts
@@ -53,6 +57,7 @@ from torchx.util.modules import import_attr
 
 GiB: int = 1024
 
+
 ResourceFactory = Callable[[], Resource]
 
 AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
@@ -61,8 +66,10 @@ AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
 GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
     "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
 )
-FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
-    "torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={}
+CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
+    os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
+    "NAMED_RESOURCES",
+    default={},
 )
 
 
@@ -73,7 +80,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
     for name, resource in {
         **GENERIC_NAMED_RESOURCES,
         **AWS_NAMED_RESOURCES,
-        **FB_NAMED_RESOURCES,
+        **CUSTOM_NAMED_RESOURCES,
         **resource_methods,
     }.items():
         materialized_resources[name] = resource
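
Because the module path is now read from the environment at import time, a
deployment can point named-resource loading at its own table before
``torchx.specs`` is first imported. A sketch with a hypothetical module name:

.. code:: python

    import os

    # Must happen before the first `import torchx.specs`, since the env var
    # is consulted when the module is imported. "my_company.named_resources"
    # is a hypothetical module exposing a NAMED_RESOURCES mapping; if it is
    # missing, the lookup falls back to the default={} above.
    os.environ["TORCHX_CUSTOM_NAMED_RESOURCES"] = "my_company.named_resources"

    from torchx import specs  # noqa: E402
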
@@ -106,8 +113,22 @@ class _NamedResourcesLibrary:
     def __contains__(self, key: str) -> bool:
         return key in _named_resource_factories
 
-    def __iter__(self) -> None:
-        raise NotImplementedError("named resources doesn't support iterating")
+    def __iter__(self) -> Iterator[str]:
+        """Iterates through the names of the registered named_resources.
+
+        Usage:
+
+        .. doctest::
+
+            from torchx import specs
+
+            for resource_name in specs.named_resources:
+                resource = specs.resource(h=resource_name)
+                assert isinstance(resource, specs.Resource)
+
+        """
+        for key in _named_resource_factories:
+            yield key
 
 
 named_resources: _NamedResourcesLibrary = _NamedResourcesLibrary()
@@ -127,7 +148,7 @@ def resource(
 
     If ``h`` is specified then it is used to look up the
     resource specs from the list of registered named resources.
-    See `registering named resource <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resource <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
 
     Otherwise a ``Resource`` object is created from the raw resource specs.
 
@@ -234,4 +255,6 @@ __all__ = [
     "torchx_run_args_from_json",
     "TorchXRunArgs",
     "ALL",
+    "TORCHX_HOME",
+    "Workspace",
 ]