torchx-nightly 2024.2.12__py3-none-any.whl → 2025.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchx-nightly might be problematic.
- torchx/__init__.py +2 -0
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +2 -0
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +30 -12
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +2 -0
- torchx/cli/cmd_tracker.py +2 -0
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +2 -0
- torchx/components/__init__.py +2 -0
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +2 -0
- torchx/components/integration_tests/component_provider.py +2 -0
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +2 -0
- torchx/components/utils.py +2 -0
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +2 -0
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +2 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
- torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
- torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/pipelines/kfp/__init__.py +2 -0
- torchx/pipelines/kfp/adapter.py +7 -4
- torchx/pipelines/kfp/version.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +78 -20
- torchx/runner/config.py +34 -3
- torchx/runner/events/__init__.py +37 -3
- torchx/runner/events/api.py +13 -2
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +10 -5
- torchx/schedulers/api.py +3 -1
- torchx/schedulers/aws_batch_scheduler.py +4 -0
- torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +38 -8
- torchx/schedulers/gcp_batch_scheduler.py +8 -9
- torchx/schedulers/ids.py +2 -0
- torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
- torchx/schedulers/kubernetes_scheduler.py +31 -5
- torchx/schedulers/local_scheduler.py +45 -6
- torchx/schedulers/lsf_scheduler.py +3 -1
- torchx/schedulers/ray/ray_driver.py +7 -7
- torchx/schedulers/ray_scheduler.py +1 -1
- torchx/schedulers/slurm_scheduler.py +3 -1
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +49 -8
- torchx/specs/api.py +87 -5
- torchx/specs/builders.py +61 -19
- torchx/specs/file_linter.py +8 -2
- torchx/specs/finder.py +2 -0
- torchx/specs/named_resources_aws.py +109 -2
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +2 -0
- torchx/tracker/api.py +4 -4
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +6 -2
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +2 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +2 -0
- torchx/util/types.py +20 -2
- torchx/version.py +3 -1
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +34 -1
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +25 -2
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/METADATA +55 -48
- torchx_nightly-2025.1.14.dist-info/RECORD +123 -0
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/WHEEL +1 -1
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/entry_points.txt +0 -1
- torchx_nightly-2024.2.12.dist-info/RECORD +0 -119
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/LICENSE +0 -0
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/top_level.txt +0 -0
torchx/specs/api.py
CHANGED
@@ -5,12 +5,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
+import asyncio
 import copy
+import inspect
 import json
+import logging as logger
 import re
+import typing
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from enum import Enum
+from json import JSONDecodeError
 from string import Template
 from typing import (
     Any,
@@ -183,11 +190,37 @@ class macros:
         apply applies the values to a copy the specified role and returns it.
         """
 
+        # Overrides might contain future values which can't be serialized so taken out for the copy
+        overrides = role.overrides
+        if len(overrides) > 0:
+            logger.warning(
+                "Role overrides are not supported for macros. Overrides will not be copied"
+            )
+        role.overrides = {}
         role = copy.deepcopy(role)
+        role.overrides = overrides
+
         role.args = [self.substitute(arg) for arg in role.args]
         role.env = {key: self.substitute(arg) for key, arg in role.env.items()}
+        role.metadata = self._apply_nested(role.metadata)
+
         return role
 
+    def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
+        stack = [d]
+        while stack:
+            current_dict = stack.pop()
+            for k, v in current_dict.items():
+                if isinstance(v, dict):
+                    stack.append(v)
+                elif isinstance(v, str):
+                    current_dict[k] = self.substitute(v)
+                elif isinstance(v, list):
+                    for i in range(len(v)):
+                        if isinstance(v[i], str):
+                            v[i] = self.substitute(v[i])
+        return d
+
     def substitute(self, arg: str) -> str:
         """
         substitute applies the values to the template arg.
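Editorial note (not part of the package diff): the new `_apply_nested` helper means macro values are now substituted into nested dicts and lists under `role.metadata`, not only `args` and `env`. A minimal illustrative sketch with hypothetical values:

```python
from torchx.specs.api import macros, Role

# Hypothetical values; field names follow the macros.Values dataclass.
values = macros.Values(
    img_root="/img", app_id="app_123", replica_id="0", rank0_env="MASTER_ADDR"
)
role = Role(
    name="worker",
    image="worker:latest",
    metadata={"labels": {"torchx/app-id": macros.app_id}},
)
applied = values.apply(role)
print(applied.metadata)  # {'labels': {'torchx/app-id': 'app_123'}}
```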
@@ -216,11 +249,13 @@ class RetryPolicy(str, Enum):
        application to deal with failed replica departures and
        replacement replica admittance.
     2. APPLICATION: Restarts the entire application.
-
+    3. ROLE: Restarts the role when any error occurs in that role. This does not
+       restart the whole job.
     """
 
     REPLICA = "REPLICA"
     APPLICATION = "APPLICATION"
+    ROLE = "ROLE"
 
 
 class MountType(str, Enum):
@@ -347,6 +382,24 @@ class Role:
     mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
         default_factory=list
     )
+    overrides: Dict[str, Any] = field(default_factory=dict)
+
+    # pyre-ignore
+    def __getattribute__(self, attrname: str) -> Any:
+        if attrname == "overrides":
+            return super().__getattribute__(attrname)
+        try:
+            ov = super().__getattribute__("overrides")
+        except AttributeError:
+            ov = {}
+        if attrname in ov:
+            if inspect.isawaitable(ov[attrname]):
+                result = asyncio.get_event_loop().run_until_complete(ov[attrname])
+            else:
+                result = ov[attrname]()
+            setattr(self, attrname, result)
+            del ov[attrname]
+        return super().__getattribute__(attrname)
 
     def pre_proc(
         self,
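Editorial note (not part of the package diff): the new `overrides` field lets a `Role` attribute be supplied lazily as a callable (or awaitable) that is resolved and cached on first attribute access. A hedged sketch of the behavior added above, using a hypothetical image override:

```python
from torchx.specs.api import Role

role = Role(name="trainer", image="placeholder:latest")
# The callable is invoked (and its result cached on the Role) the first
# time `role.image` is read; awaitables are run on the current event loop.
role.overrides = {"image": lambda: "my-registry/trainer:v2"}
print(role.image)  # my-registry/trainer:v2
```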
@@ -552,7 +605,10 @@ class AppStatus:
 
     def _format_replica_status(self, replica_status: ReplicaStatus) -> str:
         if replica_status.structured_error_msg != NONE:
-
+            try:
+                error_data = json.loads(replica_status.structured_error_msg)
+            except JSONDecodeError:
+                return replica_status.structured_error_msg
             error_message = self._format_error_message(
                 msg=error_data["message"]["message"], header="  error_msg: "
             )
@@ -637,11 +693,11 @@ class AppStatusError(Exception):
         self.status = status
 
 
-# valid run cfg values; only support primitives (str, int, float, bool, List[str])
+# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
 # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
 # in isinstance(). Should replace with that.
 # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
-CfgVal = Union[str, int, float, bool, List[str], None]
+CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]
 
 
 T = TypeVar("T")
@@ -755,6 +811,10 @@ class runopts:
         except TypeError:
             if isinstance(obj, list):
                 return all(isinstance(e, str) for e in obj)
+            elif isinstance(obj, dict):
+                return all(
+                    isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
+                )
             else:
                 return False
 
@@ -863,8 +923,13 @@ class runopts:
             # lists may be ; or , delimited
             # also deal with trailing "," by removing empty strings
             return [v for v in value.replace(";", ",").split(",") if v]
+        elif opt_type == Dict[str, str]:
+            return {
+                s.split(":", 1)[0]: s.split(":", 1)[1]
+                for s in value.replace(";", ",").split(",")
+            }
         else:
-            # pyre-ignore[19]
+            # pyre-ignore[19, 6] type won't be dict here as we handled it above
            return opt_type(value)
 
         cfg: Dict[str, CfgVal] = {}
@@ -874,6 +939,23 @@ class runopts:
             cfg[key] = _cast_to_type(val, runopt_.opt_type)
         return cfg
 
+    def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
+        """
+        Converts the given dict to a valid cfg for this ``runopts`` object.
+        """
+        cfg: Dict[str, CfgVal] = {}
+        cfg_dict = json.loads(json_repr)
+        for key, val in cfg_dict.items():
+            runopt_ = self.get(key)
+            if runopt_:
+                if runopt_.opt_type == List[str]:
+                    cfg[key] = [str(v) for v in val]
+                elif runopt_.opt_type == Dict[str, str]:
+                    cfg[key] = {str(k): str(v) for k, v in val.items()}
+                else:
+                    cfg[key] = val
+        return cfg
+
     def add(
         self,
         cfg_key: str,
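Editorial note (not part of the package diff): run config values may now be `Dict[str, str]`, and the new `cfg_from_json_repr` builds a cfg directly from a JSON object. A hedged sketch against the code added above, with hypothetical option names:

```python
from typing import Dict, List
from torchx.specs.api import runopts

opts = runopts()
opts.add("tags", type_=Dict[str, str], help="resource tags", default=None)
opts.add("queues", type_=List[str], help="candidate queues", default=None)

# JSON values are coerced per the declared opt types by the new helper.
cfg = opts.cfg_from_json_repr('{"tags": {"team": "ml"}, "queues": ["a", "b"]}')
print(cfg)  # {'tags': {'team': 'ml'}, 'queues': ['a', 'b']}
```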
torchx/specs/builders.py
CHANGED
@@ -4,25 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import argparse
 import inspect
+import os
+from argparse import Namespace
 from typing import Any, Callable, Dict, List, Mapping, Optional, Union
 
 from torchx.specs.api import BindMount, MountType, VolumeMount
 from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
-from torchx.util.types import (
-    decode_from_string,
-    decode_optional,
-    get_argparse_param_type,
-    is_bool,
-    is_primitive,
-)
+from torchx.util.types import decode, decode_optional, get_argparse_param_type, is_bool
 
 from .api import AppDef, DeviceMount
 
 
 def _create_args_parser(
-    cmpnt_fn: Callable[..., AppDef],
+    cmpnt_fn: Callable[..., AppDef],
+    cmpnt_defaults: Optional[Dict[str, str]] = None,
+    config: Optional[Dict[str, Any]] = None,
 ) -> argparse.ArgumentParser:
     parameters = inspect.signature(cmpnt_fn).parameters
     function_desc, args_desc = get_fn_docstring(cmpnt_fn)
@@ -85,15 +85,55 @@ def _create_args_parser(
         if len(param_name) == 1:
             arg_names = [f"-{param_name}"] + arg_names
         if "default" not in args:
-
+            if (config and param_name not in config) or not config:
+                args["required"] = True
+
         script_parser.add_argument(*arg_names, **args)
     return script_parser
 
 
+def _merge_config_values_with_args(
+    parsed_args: argparse.Namespace, config: Dict[str, Any]
+) -> None:
+    for key, val in config.items():
+        if key in parsed_args:
+            setattr(parsed_args, key, val)
+
+
+def parse_args(
+    cmpnt_fn: Callable[..., AppDef],
+    cmpnt_args: List[str],
+    cmpnt_defaults: Optional[Dict[str, Any]] = None,
+    config: Optional[Dict[str, Any]] = None,
+) -> Namespace:
+    """
+    Parse passed arguments, defaults, and config values into a namespace for
+    a component function.
+
+    Args:
+        cmpnt_fn: Component function
+        cmpnt_args: Function args
+        cmpnt_defaults: Additional default values for parameters of ``app_fn``
+            (overrides the defaults set on the fn declaration)
+        config: Optional dict containing additional configuration for the component from a passed config file
+
+    Returns:
+        A Namespace object with the args, defaults, and config values incorporated.
+    """
+
+    script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults, config)
+    parsed_args = script_parser.parse_args(cmpnt_args)
+    if config:
+        _merge_config_values_with_args(parsed_args, config)
+
+    return parsed_args
+
+
 def materialize_appdef(
     cmpnt_fn: Callable[..., AppDef],
     cmpnt_args: List[str],
-    cmpnt_defaults: Optional[Dict[str,
+    cmpnt_defaults: Optional[Dict[str, Any]] = None,
+    config: Optional[Dict[str, Any]] = None,
 ) -> AppDef:
     """
     Creates an application by running user defined ``app_fn``.
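Editorial note (not part of the package diff): `parse_args` now folds an optional `config` dict into the parsed namespace, and parameters covered by the config are no longer marked required by the parser. A hedged sketch with a hypothetical component function:

```python
from torchx.specs import AppDef, Role
from torchx.specs.builders import parse_args

def trainer(image: str, epochs: int = 1) -> AppDef:
    return AppDef(name="trainer", roles=[Role(name="trainer", image=image)])

# `image` is not passed on the CLI but is supplied by the config, so it is
# neither required nor missing; config values win for any keys they cover.
ns = parse_args(trainer, ["--epochs", "2"], config={"image": "trainer:v1"})
print(ns.image, ns.epochs)  # trainer:v1 2
```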
@@ -118,26 +158,23 @@ def materialize_appdef(
         cmpnt_args: Function args
         cmpnt_defaults: Additional default values for parameters of ``app_fn``
             (overrides the defaults set on the fn declaration)
+        config: Optional dict containing additional configuration for the component from a passed config file
     Returns:
         An application spec
     """
 
-    script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults)
-    parsed_args = script_parser.parse_args(cmpnt_args)
-
     function_args = []
     var_arg = []
     kwargs = {}
 
+    parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
+
     parameters = inspect.signature(cmpnt_fn).parameters
     for param_name, parameter in parameters.items():
         arg_value = getattr(parsed_args, param_name)
         parameter_type = parameter.annotation
         parameter_type = decode_optional(parameter_type)
-
-            arg_value = arg_value and arg_value.lower() == "true"
-        elif not is_primitive(parameter_type):
-            arg_value = decode_from_string(arg_value, parameter_type)
+        arg_value = decode(arg_value, parameter_type)
         if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
             var_arg = arg_value
         elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
@@ -149,7 +186,9 @@ def materialize_appdef(
     if len(var_arg) > 0 and var_arg[0] == "--":
         var_arg = var_arg[1:]
 
-
+    appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
+
+    return appdef
 
 
 def make_app_handle(scheduler_backend: str, session_name: str, app_id: str) -> str:
@@ -205,9 +244,12 @@ def parse_mounts(opts: List[str]) -> List[Union[BindMount, VolumeMount, DeviceMo
     for opts in mount_opts:
         typ = opts.get("type")
         if typ == MountType.BIND:
+            src_path = opts["src"]
+            if src_path.startswith("~"):
+                src_path = os.path.expanduser(src_path)
             mounts.append(
                 BindMount(
-                    src_path=
+                    src_path=src_path,
                     dst_path=opts["dst"],
                     read_only="readonly" in opts,
                 )
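Editorial note (not part of the package diff): bind-mount sources that start with `~` are now expanded to the caller's home directory before the `BindMount` is built. A hedged sketch:

```python
from torchx.specs.builders import parse_mounts

# Same option format the CLI forwards: key=value entries plus the bare
# "readonly" flag; the "~" in src is expanded via os.path.expanduser.
mounts = parse_mounts(["type=bind", "src=~/data", "dst=/data", "readonly"])
print(mounts[0].src_path)  # e.g. /home/<user>/data
```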
torchx/specs/file_linter.py
CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import argparse
 import ast
@@ -29,7 +31,11 @@ def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str,
     return args_decs
 
 
-class TorchXArgumentHelpFormatter(
+class TorchXArgumentHelpFormatter(
+    argparse.RawDescriptionHelpFormatter,
+    argparse.ArgumentDefaultsHelpFormatter,
+    argparse.MetavarTypeHelpFormatter,
+):
     """Help message formatter which adds default values and required to argument help.
 
     If the argument is required, the class appends `(required)` at the end of the help message.
@@ -79,7 +85,7 @@ to your component (see: https://pytorch.org/torchx/latest/component_best_practic
         args_description[param.arg_name] = param.description
     short_func_description = docstring.short_description or default_fn_desc
     if docstring.long_description:
-        short_func_description += "
+        short_func_description += "\n" + docstring.long_description
     return (short_func_description or default_fn_desc, args_description)
torchx/specs/finder.py
CHANGED
torchx/specs/named_resources_aws.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 r"""
 `torchx.specs.named_resources_aws` contains resource definitions that represent corresponding AWS instance types
 taken from https://aws.amazon.com/ec2/instance-types/. The resources are exposed
@@ -35,6 +37,7 @@ from typing import Callable, Mapping
 from torchx.specs.api import Resource
 
 EFA_DEVICE = "vpc.amazonaws.com/efa"
+NEURON_DEVICE = "aws.amazon.com/neurondevice"
 
 # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits
 # so we have to account for mem tax when registering these resources for AWS
@@ -107,6 +110,16 @@ def aws_p4de_24xlarge() -> Resource:
     )
 
 
+def aws_p5_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5.48xlarge"},
+        devices={EFA_DEVICE: 32},
+    )
+
+
 def aws_t3_medium() -> Resource:
     return Resource(cpu=2, gpu=0, memMB=4 * GiB, capabilities={K8S_ITYPE: "t3.medium"})
@@ -117,6 +130,12 @@ def aws_m5_2xlarge() -> Resource:
     )
 
 
+def aws_c5_18xlarge() -> Resource:
+    return Resource(
+        cpu=72, gpu=0, memMB=144 * GiB, capabilities={K8S_ITYPE: "c5.18xlarge"}
+    )
+
+
 def aws_g4dn_xlarge() -> Resource:
     return Resource(
         cpu=4, gpu=1, memMB=16 * GiB, capabilities={K8S_ITYPE: "g4dn.xlarge"}
@@ -241,9 +260,87 @@ def aws_g5_48xlarge() -> Resource:
     )
 
 
+def aws_g6e_xlarge() -> Resource:
+    return Resource(
+        cpu=4,
+        gpu=1,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "g6e.xlarge"},
+    )
+
+
+def aws_g6e_2xlarge() -> Resource:
+    return Resource(
+        cpu=8,
+        gpu=1,
+        memMB=64 * GiB,
+        capabilities={K8S_ITYPE: "g6e.2xlarge"},
+    )
+
+
+def aws_g6e_4xlarge() -> Resource:
+    return Resource(
+        cpu=16,
+        gpu=1,
+        memMB=128 * GiB,
+        capabilities={K8S_ITYPE: "g6e.4xlarge"},
+    )
+
+
+def aws_g6e_8xlarge() -> Resource:
+    return Resource(
+        cpu=32,
+        gpu=1,
+        memMB=256 * GiB,
+        capabilities={K8S_ITYPE: "g6e.8xlarge"},
+    )
+
+
+def aws_g6e_16xlarge() -> Resource:
+    return Resource(
+        cpu=64,
+        gpu=1,
+        memMB=512 * GiB,
+        capabilities={K8S_ITYPE: "g6e.16xlarge"},
+    )
+
+
+def aws_g6e_12xlarge() -> Resource:
+    return Resource(
+        cpu=48,
+        gpu=4,
+        memMB=384 * GiB,
+        capabilities={K8S_ITYPE: "g6e.12xlarge"},
+    )
+
+
+def aws_g6e_24xlarge() -> Resource:
+    return Resource(
+        cpu=96,
+        gpu=4,
+        memMB=768 * GiB,
+        capabilities={K8S_ITYPE: "g6e.24xlarge"},
+        devices={EFA_DEVICE: 2},
+    )
+
+
+def aws_g6e_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=1536 * GiB,
+        capabilities={K8S_ITYPE: "g6e.48xlarge"},
+        devices={EFA_DEVICE: 4},
+    )
+
+
 def aws_trn1_2xlarge() -> Resource:
     return Resource(
-        cpu=8,
+        cpu=8,
+        gpu=0,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "trn1.2xlarge"},
+        devices={NEURON_DEVICE: 1},
     )
 
 
@@ -253,19 +350,21 @@ def aws_trn1_32xlarge() -> Resource:
         gpu=0,
         memMB=512 * GiB,
         capabilities={K8S_ITYPE: "trn1.32xlarge"},
-        devices={EFA_DEVICE: 8},
+        devices={EFA_DEVICE: 8, NEURON_DEVICE: 16},
     )
 
 
 NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_t3.medium": aws_t3_medium,
     "aws_m5.2xlarge": aws_m5_2xlarge,
+    "aws_c5.18xlarge": aws_c5_18xlarge,
     "aws_p3.2xlarge": aws_p3_2xlarge,
     "aws_p3.8xlarge": aws_p3_8xlarge,
     "aws_p3.16xlarge": aws_p3_16xlarge,
     "aws_p3dn.24xlarge": aws_p3dn_24xlarge,
     "aws_p4d.24xlarge": aws_p4d_24xlarge,
     "aws_p4de.24xlarge": aws_p4de_24xlarge,
+    "aws_p5.48xlarge": aws_p5_48xlarge,
     "aws_g4dn.xlarge": aws_g4dn_xlarge,
     "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
     "aws_g4dn.4xlarge": aws_g4dn_4xlarge,
@@ -281,6 +380,14 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_g5.12xlarge": aws_g5_12xlarge,
     "aws_g5.24xlarge": aws_g5_24xlarge,
     "aws_g5.48xlarge": aws_g5_48xlarge,
+    "aws_g6e.xlarge": aws_g6e_xlarge,
+    "aws_g6e.2xlarge": aws_g6e_2xlarge,
+    "aws_g6e.4xlarge": aws_g6e_4xlarge,
+    "aws_g6e.8xlarge": aws_g6e_8xlarge,
+    "aws_g6e.16xlarge": aws_g6e_16xlarge,
+    "aws_g6e.12xlarge": aws_g6e_12xlarge,
+    "aws_g6e.24xlarge": aws_g6e_24xlarge,
+    "aws_g6e.48xlarge": aws_g6e_48xlarge,
     "aws_trn1.2xlarge": aws_trn1_2xlarge,
     "aws_trn1.32xlarge": aws_trn1_32xlarge,
 }
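Editorial note (not part of the package diff): the new instance types resolve like any other AWS named resource. A hedged sketch, assuming the module's `GiB` constant of 1024 MB:

```python
from torchx.specs.named_resources_aws import NAMED_RESOURCES

# 384 * GiB == 393216 memMB for the g6e.12xlarge definition added above.
r = NAMED_RESOURCES["aws_g6e.12xlarge"]()
print(r.cpu, r.gpu, r.memMB)  # 48 4 393216
```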
torchx/specs/named_resources_generic.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Defines generic named resources that are not specific to any cloud provider's
 instance types. These generic named resources are meant to be used as
torchx/tracker/__init__.py
CHANGED
torchx/tracker/api.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from __future__ import annotations
 
 import logging
@@ -67,8 +69,7 @@ class AppRunTrackableSource:
     artifact_name: Optional[str]
 
 
-class Lineage:
-    ...
+class Lineage: ...
 
 
 class TrackerBase(ABC):
@@ -332,5 +333,4 @@ class AppRun:
 
         return model_run_sources
 
-    def children(self) -> Iterable[AppRun]:
-        ...
+    def children(self) -> Iterable[AppRun]: ...
torchx/tracker/backend/fsspec.py
CHANGED
torchx/util/cuda.py
CHANGED