torchx-nightly 2025.10.16__py3-none-any.whl → 2025.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchx-nightly has been flagged as possibly problematic by the registry.
- torchx/_version.py +8 -0
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/main.py +2 -0
- torchx/runner/api.py +35 -33
- torchx/schedulers/api.py +58 -17
- torchx/schedulers/aws_batch_scheduler.py +2 -4
- torchx/schedulers/aws_sagemaker_scheduler.py +1 -1
- torchx/schedulers/docker_scheduler.py +1 -3
- torchx/schedulers/kubernetes_mcad_scheduler.py +1 -4
- torchx/schedulers/kubernetes_scheduler.py +234 -20
- torchx/schedulers/local_scheduler.py +1 -1
- torchx/schedulers/lsf_scheduler.py +1 -1
- torchx/schedulers/slurm_scheduler.py +9 -3
- torchx/specs/__init__.py +17 -3
- torchx/specs/api.py +82 -41
- torchx/version.py +2 -2
- torchx/workspace/api.py +63 -42
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info}/METADATA +21 -8
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info}/RECORD +23 -21
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info}/WHEEL +1 -1
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.10.16.dist-info → torchx_nightly-2025.12.2.dist-info}/top_level.txt +0 -0
@@ -27,10 +27,81 @@ Install Volcano:
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
+the ``kubernetes`` metadata on your role. The value can be:
+
+- A dict with the overlay structure
+- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)
+
+Merge semantics:
+- **dict**: recursive merge (upsert)
+- **list**: append by default, replace if tuple (Python) or ``!!python/tuple`` tag (YAML)
+- **primitives**: replace
+
+.. code:: python
+
+    from torchx.specs import Role
+
+    # Dict overlay - lists append, tuples replace
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": {
+                "spec": {
+                    "nodeSelector": {"gpu": "true"},
+                    "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}], # appends
+                    "volumes": ({"name": "my-volume", "emptyDir": {}},) # replaces
+                }
+            }
+        }
+    )
+
+    # File URI overlay
+    role = Role(
+        name="trainer",
+        image="my-image:latest",
+        entrypoint="train.py",
+        metadata={
+            "kubernetes": "file:///path/to/pod_overlay.yaml"
+        }
+    )
+
+CLI usage with builtin components:
+
+.. code:: bash
+
+    $ torchx run --scheduler kubernetes dist.ddp \\
+        --metadata kubernetes=file:///path/to/pod_overlay.yaml \\
+        --script train.py
+
+Example ``pod_overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+      - key: nvidia.com/gpu
+        operator: Exists
+        effect: NoSchedule
+      volumes: !!python/tuple
+      - name: my-volume
+        emptyDir: {}
+
+The overlay is deep-merged with the generated pod, preserving existing fields
+and adding or overriding specified ones.
 """

 import json
 import logging
+import re
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
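A standalone sketch, not part of the package diff: the ``!!python/tuple`` tag only takes effect when the overlay file is parsed with a loader that understands Python tags, which is why the new code path uses ``yaml.unsafe_load``. The inline YAML below is illustrative.

.. code:: python

    # Sketch: loading an overlay the way the new code path does, with PyYAML's
    # unsafe loader so the !!python/tuple tag is honored. The YAML is made up.
    import yaml

    overlay_yaml = """
    spec:
      nodeSelector:
        node.kubernetes.io/instance-type: p4d.24xlarge
      volumes: !!python/tuple
      - name: my-volume
        emptyDir: {}
    """

    overlay = yaml.unsafe_load(overlay_yaml)
    assert isinstance(overlay["spec"]["volumes"], tuple)      # tuple -> "replace" semantics
    assert isinstance(overlay["spec"]["nodeSelector"], dict)  # dict -> recursive merge

Using ``yaml.safe_load`` here would reject the Python-specific tag, so a plain dict overlay (append semantics for lists) is the safer default when the overlay comes from an untrusted source.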
@@ -45,6 +116,7 @@ from typing import (
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )

 import torchx
@@ -97,6 +169,40 @@ logger: logging.Logger = logging.getLogger(__name__)
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024

+
+def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+    """Apply overlay dict to V1Pod object, merging nested fields.
+
+    Merge semantics:
+    - dict: upsert (recursive merge)
+    - list: append by default, replace if tuple
+    - primitives: replace
+    """
+    from kubernetes import client
+
+    api = client.ApiClient()
+    pod_dict = api.sanitize_for_serialization(pod)
+
+    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+        for key, value in overlay.items():
+            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                deep_merge(base[key], value)
+            elif isinstance(value, tuple):
+                base[key] = list(value)
+            elif (
+                isinstance(value, list) and key in base and isinstance(base[key], list)
+            ):
+                base[key].extend(value)
+            else:
+                base[key] = value
+
+    deep_merge(pod_dict, overlay)
+
+    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+    pod.spec = merged_pod.spec
+    pod.metadata = merged_pod.metadata
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
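A standalone sketch, not part of the package diff: the merge rules implemented by ``deep_merge`` above, replayed on plain dicts so they can be checked without a Kubernetes client. The pod and overlay dicts are made up.

.. code:: python

    # Sketch of the documented merge rules: dict upsert, list append,
    # tuple replace, primitive/new-key replace.
    from typing import Any, Dict


    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)          # dict: recursive upsert
            elif isinstance(value, tuple):
                base[key] = list(value)               # tuple: replace
            elif isinstance(value, list) and key in base and isinstance(base[key], list):
                base[key].extend(value)               # list: append
            else:
                base[key] = value                     # primitive / new key: replace


    pod = {"spec": {"tolerations": [{"key": "a"}], "volumes": [{"name": "old"}]}}
    overlay = {"spec": {"tolerations": [{"key": "b"}], "volumes": ({"name": "new"},)}}
    deep_merge(pod, overlay)
    assert pod["spec"]["tolerations"] == [{"key": "a"}, {"key": "b"}]  # appended
    assert pod["spec"]["volumes"] == [{"name": "new"}]                 # replaced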
@@ -369,7 +475,7 @@ def app_to_resource(
     queue: str,
     service_account: Optional[str],
     priority_class: Optional[str] = None,
-) -> Dict[str,
+) -> Dict[str, Any]:
     """
     app_to_resource creates a volcano job kubernetes resource definition from
     the provided AppDef. The resource definition can be used to launch the
@@ -402,6 +508,17 @@ def app_to_resource(
             replica_role.env["TORCHX_IMAGE"] = replica_role.image

             pod = role_to_pod(name, replica_role, service_account)
+            if k8s_metadata := role.metadata.get("kubernetes"):
+                if isinstance(k8s_metadata, str):
+                    import fsspec
+
+                    with fsspec.open(k8s_metadata, "r") as f:
+                        k8s_metadata = yaml.unsafe_load(f)
+                elif not isinstance(k8s_metadata, dict):
+                    raise ValueError(
+                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
+                    )
+                _apply_pod_overlay(pod, k8s_metadata)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
@@ -444,7 +561,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
     if priority_class is not None:
         job_spec["priorityClassName"] = priority_class

-    resource: Dict[str,
+    resource: Dict[str, Any] = {
         "apiVersion": "batch.volcano.sh/v1alpha1",
         "kind": "Job",
         "metadata": {"name": f"{unique_app_id}"},
@@ -456,7 +573,7 @@ does NOT support retries correctly. More info: https://github.com/volcano-sh/vol
 @dataclass
 class KubernetesJob:
     images_to_push: Dict[str, Tuple[str, str]]
-    resource: Dict[str,
+    resource: Dict[str, Any]

     def __str__(self) -> str:
         return yaml.dump(sanitize_for_serialization(self.resource))
@@ -471,12 +588,10 @@ class KubernetesOpts(TypedDict, total=False):
     image_repo: Optional[str]
     service_account: Optional[str]
     priority_class: Optional[str]
+    validate_spec: Optional[bool]


-class KubernetesScheduler(
-    DockerWorkspaceMixin,
-    Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]],
-):
+class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
     """
     KubernetesScheduler is a TorchX scheduling interface to Kubernetes.

@@ -504,6 +619,16 @@ class KubernetesScheduler(
         $ torchx status kubernetes://torchx_user/1234
         ...

+    **Cancellation**
+
+    Canceling a job aborts it while preserving the job spec for inspection
+    and cloning via kubectl apply. Use the delete command to remove the job entirely:
+
+    .. code-block:: bash
+
+        $ torchx cancel kubernetes://namespace/jobname # abort, preserves spec
+        $ torchx delete kubernetes://namespace/jobname # delete completely
+
     **Config Options**

     .. runopts::
@@ -636,7 +761,7 @@ class KubernetesScheduler(
             else:
                 raise

-        return f
+        return f"{namespace}:{resp['metadata']['name']}"

     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts
@@ -659,6 +784,36 @@ class KubernetesScheduler(
         ), "priority_class must be a str"

         resource = app_to_resource(app, queue, service_account, priority_class)
+
+        if cfg.get("validate_spec"):
+            try:
+                self._custom_objects_api().create_namespaced_custom_object(
+                    group="batch.volcano.sh",
+                    version="v1alpha1",
+                    namespace=cfg.get("namespace") or "default",
+                    plural="jobs",
+                    body=resource,
+                    dry_run="All",
+                )
+            except Exception as e:
+                from kubernetes.client.rest import ApiException
+
+                if isinstance(e, ApiException):
+                    raise ValueError(f"Invalid job spec: {e.reason}") from e
+                raise
+
+        job_name = resource["metadata"]["name"]
+        for task in resource["spec"]["tasks"]:
+            task_name = task["name"]
+            replicas = task.get("replicas", 1)
+            max_index = replicas - 1
+            pod_name = f"{job_name}-{task_name}-{max_index}"
+            if len(pod_name) > 63:
+                raise ValueError(
+                    f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
+                    f"Shorten app.name or role names"
+                )
+
         req = KubernetesJob(
             resource=resource,
             images_to_push=images_to_push,
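A standalone sketch, not part of the package diff: the arithmetic behind the new pre-submission check. The check assumes pod names of the form ``{job_name}-{task_name}-{replica_index}``, and the longest one must fit the Kubernetes 63-character object-name limit. The resource dict below is a toy example.

.. code:: python

    # Sketch: worst-case pod name length for each Volcano task in a resource dict.
    resource = {
        "metadata": {"name": "trainer-app-x1y2z3"},
        "spec": {"tasks": [{"name": "worker", "replicas": 16}]},
    }

    job_name = resource["metadata"]["name"]
    for task in resource["spec"]["tasks"]:
        max_index = task.get("replicas", 1) - 1          # highest replica suffix
        pod_name = f"{job_name}-{task['name']}-{max_index}"
        if len(pod_name) > 63:
            raise ValueError(f"pod name too long: {pod_name!r} ({len(pod_name)} chars)")
        print(pod_name, len(pod_name))  # trainer-app-x1y2z3-worker-15 28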
@@ -670,6 +825,31 @@ class KubernetesScheduler(
         pass

     def _cancel_existing(self, app_id: str) -> None:
+        """
+        Abort a Volcano job while preserving the spec for inspection.
+        """
+        namespace, name = app_id.split(":")
+        vcjob = self._custom_objects_api().get_namespaced_custom_object(
+            group="batch.volcano.sh",
+            version="v1alpha1",
+            namespace=namespace,
+            plural="jobs",
+            name=name,
+        )
+        vcjob["status"]["state"]["phase"] = "Aborted"
+        self._custom_objects_api().replace_namespaced_custom_object_status(
+            group="batch.volcano.sh",
+            version="v1alpha1",
+            namespace=namespace,
+            plural="jobs",
+            name=name,
+            body=vcjob,
+        )
+
+    def _delete_existing(self, app_id: str) -> None:
+        """
+        Delete a Volcano job completely from the cluster.
+        """
         namespace, name = app_id.split(":")
         self._custom_objects_api().delete_namespaced_custom_object(
             group="batch.volcano.sh",
@@ -703,19 +883,32 @@ class KubernetesScheduler(
             type_=str,
             help="The name of the PriorityClass to set on the job specs",
         )
+        opts.add(
+            "validate_spec",
+            type_=bool,
+            help="Validate job spec using Kubernetes API dry-run before submission",
+            default=True,
+        )
         return opts

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        from kubernetes.client.rest import ApiException
+
         namespace, name = app_id.split(":")
         roles = {}
         roles_statuses = {}
-        resp = self._custom_objects_api().get_namespaced_custom_object_status(
-            group="batch.volcano.sh",
-            version="v1alpha1",
-            namespace=namespace,
-            plural="jobs",
-            name=name,
-        )
+        try:
+            resp = self._custom_objects_api().get_namespaced_custom_object_status(
+                group="batch.volcano.sh",
+                version="v1alpha1",
+                namespace=namespace,
+                plural="jobs",
+                name=name,
+            )
+        except ApiException as e:
+            if e.status == 404:
+                return None
+            raise
         status = resp.get("status")
         if status:
             state_str = status["state"]["phase"]
@@ -824,13 +1017,34 @@ def create_scheduler(
 def pod_labels(
     app: AppDef, role_idx: int, role: Role, replica_id: int, app_id: str
 ) -> Dict[str, str]:
+
+    def clean(label_value: str) -> str:
+        # cleans the provided `label_value` to make it compliant
+        # to pod label specs as described in
+        # https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/
+        #
+        # Valid label value:
+        # must be 63 characters or less (can be empty),
+        # unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]),
+        # could contain dashes (-), underscores (_), dots (.), and alphanumerics between.

+        # Replace invalid characters (allow: alphanum, -, _, .) with "."
+        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)
+        # Replace leading non-alphanumeric with "."
+        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)
+        # Replace trailing non-alphanumeric with "."
+        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)
+
+        # Trim to 63 characters
+        return label_value[:63]
+
     return {
-        LABEL_VERSION: torchx.__version__,
-        LABEL_APP_NAME: app.name,
+        LABEL_VERSION: clean(torchx.__version__),
+        LABEL_APP_NAME: clean(app.name),
         LABEL_ROLE_INDEX: str(role_idx),
-        LABEL_ROLE_NAME: role.name,
+        LABEL_ROLE_NAME: clean(role.name),
         LABEL_REPLICA_ID: str(replica_id),
-        LABEL_KUBE_APP_NAME: app.name,
+        LABEL_KUBE_APP_NAME: clean(app.name),
         LABEL_ORGANIZATION: "torchx.pytorch.org",
-        LABEL_UNIQUE_NAME: app_id,
+        LABEL_UNIQUE_NAME: clean(app_id),
     }
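A standalone sketch, not part of the package diff: the label sanitizer added inside ``pod_labels``, reproduced as a stand-alone function with a couple of worked examples. The sample values are made up.

.. code:: python

    # Sketch: invalid characters become ".", leading/trailing non-alphanumeric
    # runs are collapsed to ".", and the value is trimmed to 63 characters.
    import re


    def clean(label_value: str) -> str:
        label_value = re.sub(r"[^A-Za-z0-9\-_.]", ".", label_value)  # invalid chars -> "."
        label_value = re.sub(r"^[^A-Za-z0-9]+", ".", label_value)    # leading run -> "."
        label_value = re.sub(r"[^A-Za-z0-9]+$", ".", label_value)    # trailing run -> "."
        return label_value[:63]                                      # 63-char limit


    assert clean("2025.12.2+nightly") == "2025.12.2.nightly"
    assert clean("torchx_user/1234") == "torchx_user.1234"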
torchx/schedulers/local_scheduler.py
CHANGED

@@ -529,7 +529,7 @@ def _register_termination_signals() -> None:
     signal.signal(signal.SIGINT, _terminate_process_handler)


-class LocalScheduler(Scheduler[LocalOpts
+class LocalScheduler(Scheduler[LocalOpts]):
     """
     Schedules on localhost. Containers are modeled as processes and
     certain properties of the container that are either not relevant
torchx/schedulers/slurm_scheduler.py
CHANGED

@@ -135,6 +135,7 @@ SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
     "mail-type",
+    "account",
 }
 SBATCH_GROUP_OPTIONS = {
     "partition",
@@ -159,6 +160,7 @@ def _apply_app_id_env(s: str) -> str:
 SlurmOpts = TypedDict(
     "SlurmOpts",
     {
+        "account": Optional[str],
         "partition": str,
         "time": str,
         "comment": Optional[str],
@@ -335,9 +337,7 @@ fi
 {self.materialize()}"""


-class SlurmScheduler(
-    DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]]
-):
+class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
     """
     SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
     that slurm CLI tools are locally installed and job accounting is enabled.
@@ -406,6 +406,12 @@ class SlurmScheduler(

     def _run_opts(self) -> runopts:
         opts = runopts()
+        opts.add(
+            "account",
+            type_=str,
+            help="The account to use for the slurm job.",
+            default=None,
+        )
         opts.add(
             "partition",
             type_=str,
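A standalone sketch, not part of the package diff: passing the new slurm ``account`` run option programmatically. The ``AppDef``/``Role`` construction follows the torchx specs API; the ``Runner.run()`` call is based on the public torchx runner API and is an assumption here, as are the account and partition values.

.. code:: python

    # Sketch: submitting to the slurm scheduler with the new `account` option.
    from torchx.runner import get_runner
    from torchx.specs import AppDef, Role

    app = AppDef(
        name="trainer",
        roles=[Role(name="trainer", image="/opt/trainer", entrypoint="train.py")],
    )

    runner = get_runner()
    app_handle = runner.run(
        app,
        scheduler="slurm",
        cfg={"account": "my-slurm-account", "partition": "gpu"},  # placeholder values
    )
    print(app_handle)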
torchx/specs/__init__.py
CHANGED
@@ -14,7 +14,7 @@ scheduler or pipeline adapter.
 import difflib

 import os
-from typing import Callable, Dict, Mapping, Optional
+from typing import Callable, Dict, Iterator, Mapping, Optional

 from torchx.specs.api import (
     ALL,
@@ -113,8 +113,22 @@ class _NamedResourcesLibrary:
     def __contains__(self, key: str) -> bool:
         return key in _named_resource_factories

-    def __iter__(self) ->
-
+    def __iter__(self) -> Iterator[str]:
+        """Iterates through the names of the registered named_resources.
+
+        Usage:
+
+        .. doctest::
+
+            from torchx import specs
+
+            for resource_name in specs.named_resources:
+                resource = specs.resource(h=resource_name)
+                assert isinstance(resource, specs.Resource)
+
+        """
+        for key in _named_resource_factories:
+            yield (key)


 named_resources: _NamedResourcesLibrary = _NamedResourcesLibrary()
torchx/specs/api.py
CHANGED
@@ -14,10 +14,12 @@ import logging as logger
 import os
 import pathlib
 import re
+import shutil
 import typing
+import warnings
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
-from enum import Enum
+from enum import Enum, IntEnum
 from json import JSONDecodeError
 from string import Template
 from typing import (
@@ -251,7 +253,9 @@ class macros:
                     current_dict[k] = self.substitute(v)
                 elif isinstance(v, list):
                     for i in range(len(v)):
-                        if isinstance(v[i],
+                        if isinstance(v[i], dict):
+                            stack.append(v[i])
+                        elif isinstance(v[i], str):
                             v[i] = self.substitute(v[i])
         return d

@@ -380,6 +384,16 @@ class Workspace:
         """False if no projects mapping. Lets us use workspace object in an if-statement"""
         return bool(self.projects)

+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Workspace):
+            return False
+        return self.projects == other.projects
+
+    def __hash__(self) -> int:
+        # makes it possible to use Workspace as the key in the workspace build cache
+        # see WorkspaceMixin.caching_build_workspace_and_update_role
+        return hash(frozenset(self.projects.items()))
+
     def is_unmapped_single_project(self) -> bool:
         """
         Returns ``True`` if this workspace only has 1 project
@@ -387,6 +401,39 @@ class Workspace:
         """
         return len(self.projects) == 1 and not next(iter(self.projects.values()))

+    def merge_into(self, outdir: str | pathlib.Path) -> None:
+        """
+        Copies each project dir of this workspace into the specified ``outdir``.
+        Each project dir is copied into ``{outdir}/{target}`` where ``target`` is
+        the target mapping of the project dir.
+
+        For example:
+
+        .. code-block:: python
+            from os.path import expanduser
+
+            workspace = Workspace(
+                projects={
+                    expanduser("~/workspace/torch"): "torch",
+                    expanduser("~/workspace/my_project": "")
+                }
+            )
+            workspace.merge_into(expanduser("~/tmp"))
+
+        Copies:
+
+        * ``~/workspace/torch/**`` into ``~/tmp/torch/**``
+        * ``~/workspace/my_project/**`` into ``~/tmp/**``
+
+        """
+
+        for src, dst in self.projects.items():
+            dst_path = pathlib.Path(outdir) / dst
+            if pathlib.Path(src).is_file():
+                shutil.copy2(src, dst_path)
+            else:  # src is dir
+                shutil.copytree(src, dst_path, dirs_exist_ok=True)
+
     @staticmethod
     def from_str(workspace: str | None) -> "Workspace":
         import yaml
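A standalone sketch, not part of the package diff: the new ``__eq__``/``__hash__`` make value-equal ``Workspace`` objects interchangeable as dict keys, which is what the build-cache comment above refers to. The project paths and the cache contents below are made up.

.. code:: python

    # Sketch: two Workspace objects with the same project mapping hash equal
    # and hit the same entry of a (hypothetical) build cache dict.
    from torchx.specs.api import Workspace

    ws_a = Workspace(projects={"~/workspace/torch": "torch"})
    ws_b = Workspace(projects={"~/workspace/torch": "torch"})

    assert ws_a == ws_b
    assert hash(ws_a) == hash(ws_b)

    build_cache = {ws_a: "sha256:abc123"}        # hypothetical image digest
    assert build_cache[ws_b] == "sha256:abc123"  # same cache entry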
@@ -891,14 +938,12 @@ class runopt:
     Represents the metadata about the specific run option
     """

-    class alias(str):
-        pass
-
     default: CfgVal
     opt_type: Type[CfgVal]
     is_required: bool
     help: str
-    aliases: list[
+    aliases: list[str] | None = None
+    deprecated_aliases: list[str] | None = None

     @property
     def is_type_list_of_str(self) -> bool:
@@ -990,7 +1035,7 @@ class runopts:

     def __init__(self) -> None:
         self._opts: Dict[str, runopt] = {}
-        self._alias_to_key: dict[
+        self._alias_to_key: dict[str, str] = {}

     def __iter__(self) -> Iterator[Tuple[str, runopt]]:
         return self._opts.items().__iter__()
@@ -1044,12 +1089,24 @@ class runopts:
             val = resolved_cfg.get(cfg_key)
             resolved_name = None
             aliases = runopt.aliases or []
+            deprecated_aliases = runopt.deprecated_aliases or []
             if val is None:
                 for alias in aliases:
                     val = resolved_cfg.get(alias)
                     if alias in cfg or val is not None:
                         resolved_name = alias
                         break
+                for alias in deprecated_aliases:
+                    val = resolved_cfg.get(alias)
+                    if val is not None:
+                        resolved_name = alias
+                        use_instead = self._alias_to_key.get(alias)
+                        warnings.warn(
+                            f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
+                            UserWarning,
+                            stacklevel=2,
+                        )
+                        break
             else:
                 resolved_name = cfg_key
                 for alias in aliases:
@@ -1172,49 +1229,23 @@ class runopts:
             cfg[key] = val
         return cfg

-    def _get_primary_key_and_aliases(
-        self,
-        cfg_key: list[str] | str,
-    ) -> tuple[str, list[runopt.alias]]:
-        """
-        Returns the primary key and aliases for the given cfg_key.
-        """
-        if isinstance(cfg_key, str):
-            return cfg_key, []
-
-        if len(cfg_key) == 0:
-            raise ValueError("cfg_key must be a non-empty list")
-        primary_key = None
-        aliases = list[runopt.alias]()
-        for name in cfg_key:
-            if isinstance(name, runopt.alias):
-                aliases.append(name)
-            else:
-                if primary_key is not None:
-                    raise ValueError(
-                        f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
-                    )
-                primary_key = name
-        if primary_key is None or primary_key == "":
-            raise ValueError(
-                "Missing cfg_key. Please provide one other than the aliases."
-            )
-        return primary_key, aliases
-
     def add(
         self,
-        cfg_key: str
+        cfg_key: str,
         type_: Type[CfgVal],
         help: str,
         default: CfgVal = None,
         required: bool = False,
+        aliases: Optional[list[str]] = None,
+        deprecated_aliases: Optional[list[str]] = None,
     ) -> None:
         """
         Adds the ``config`` option with the given help string and ``default``
         value (if any). If the ``default`` is not specified then this option
         is a required option.
         """
-
+        aliases = aliases or []
+        deprecated_aliases = deprecated_aliases or []
         if required and default is not None:
             raise ValueError(
                 f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1225,10 +1256,20 @@ class runopts:
                 f"Option: {cfg_key}, must be of type: {type_}."
                 f" Given: {default} ({type(default).__name__})"
             )
-
+
+        opt = runopt(
+            default,
+            type_,
+            required,
+            help,
+            list(set(aliases)),
+            list(set(deprecated_aliases)),
+        )
         for alias in aliases:
-            self._alias_to_key[alias] =
-
+            self._alias_to_key[alias] = cfg_key
+        for deprecated_alias in deprecated_aliases:
+            self._alias_to_key[deprecated_alias] = cfg_key
+        self._opts[cfg_key] = opt

     def update(self, other: "runopts") -> None:
         self._opts.update(other._opts)
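A standalone sketch, not part of the package diff: registering a run option with a ``deprecated_aliases`` entry. The option name and alias are made up, and ``runopts.resolve()`` is assumed to be the entry point that performs the alias resolution shown in the hunks above (the enclosing method name is not visible in this diff).

.. code:: python

    # Sketch: a config passed under the deprecated alias still resolves to the
    # primary key and emits a UserWarning pointing at the new name.
    import warnings

    from torchx.specs.api import runopts

    opts = runopts()
    opts.add(
        "job_account",                      # hypothetical primary option name
        type_=str,
        help="billing account for the job",
        default=None,
        deprecated_aliases=["acct"],        # hypothetical old name
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        resolved = opts.resolve({"acct": "team-a"})

    print(resolved.get("job_account"))       # expected: "team-a"
    print([str(w.message) for w in caught])  # expected: a deprecation warning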
torchx/version.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -7,6 +6,7 @@

 # pyre-strict

+from torchx._version import BASE_VERSION
 from torchx.util.entrypoints import load

 # Follows PEP-0440 version scheme guidelines
@@ -18,7 +18,7 @@ from torchx.util.entrypoints import load
 # 0.1.0bN # Beta release
 # 0.1.0rcN # Release Candidate
 # 0.1.0 # Final release
-__version__ =
+__version__: str = BASE_VERSION


 # Use the github container registry images corresponding to the current package
|