torchx-nightly 2025.7.9__py3-none-any.whl → 2025.11.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +92 -81
- torchx/runner/config.py +11 -9
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +20 -15
- torchx/schedulers/aws_batch_scheduler.py +45 -2
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/local_scheduler.py +1 -0
- torchx/schedulers/slurm_scheduler.py +160 -26
- torchx/specs/__init__.py +23 -6
- torchx/specs/api.py +279 -33
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/specs/builders.py
CHANGED
|
@@ -4,13 +4,13 @@
|
|
|
4
4
|
# This source code is licensed under the BSD-style license found in the
|
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
|
6
6
|
|
|
7
|
-
# pyre-
|
|
7
|
+
# pyre-unsafe
|
|
8
8
|
|
|
9
9
|
import argparse
|
|
10
10
|
import inspect
|
|
11
11
|
import os
|
|
12
12
|
from argparse import Namespace
|
|
13
|
-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|
13
|
+
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
|
|
14
14
|
|
|
15
15
|
from torchx.specs.api import BindMount, MountType, VolumeMount
|
|
16
16
|
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
|
|
@@ -19,6 +19,14 @@ from torchx.util.types import decode, decode_optional, get_argparse_param_type,
|
|
|
19
19
|
from .api import AppDef, DeviceMount
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
class ComponentArgs(NamedTuple):
|
|
23
|
+
"""Parsed component function arguments"""
|
|
24
|
+
|
|
25
|
+
positional_args: dict[str, Any]
|
|
26
|
+
var_args: list[str]
|
|
27
|
+
kwargs: dict[str, Any]
|
|
28
|
+
|
|
29
|
+
|
|
22
30
|
def _create_args_parser(
|
|
23
31
|
cmpnt_fn: Callable[..., AppDef],
|
|
24
32
|
cmpnt_defaults: Optional[Dict[str, str]] = None,
|
|
@@ -31,7 +39,7 @@ def _create_args_parser(
|
|
|
31
39
|
|
|
32
40
|
|
|
33
41
|
def _create_args_parser_from_parameters(
|
|
34
|
-
cmpnt_fn: Callable[...,
|
|
42
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
35
43
|
parameters: Mapping[str, inspect.Parameter],
|
|
36
44
|
cmpnt_defaults: Optional[Dict[str, str]] = None,
|
|
37
45
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -112,7 +120,7 @@ def _merge_config_values_with_args(
|
|
|
112
120
|
|
|
113
121
|
|
|
114
122
|
def parse_args(
|
|
115
|
-
cmpnt_fn: Callable[...,
|
|
123
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
116
124
|
cmpnt_args: List[str],
|
|
117
125
|
cmpnt_defaults: Optional[Dict[str, Any]] = None,
|
|
118
126
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -140,8 +148,97 @@ def parse_args(
|
|
|
140
148
|
return parsed_args
|
|
141
149
|
|
|
142
150
|
|
|
151
|
+
def component_args_from_str(
|
|
152
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
153
|
+
cmpnt_args: list[str],
|
|
154
|
+
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
|
|
155
|
+
config: Optional[Dict[str, Any]] = None,
|
|
156
|
+
) -> ComponentArgs:
|
|
157
|
+
"""
|
|
158
|
+
Parses and decodes command-line arguments for a component function.
|
|
159
|
+
|
|
160
|
+
This function takes a component function and its arguments, parses them using argparse,
|
|
161
|
+
and decodes the arguments into their expected types based on the function's signature.
|
|
162
|
+
It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
cmpnt_fn: The component function whose arguments are to be parsed and decoded.
|
|
166
|
+
cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
|
|
167
|
+
cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
|
|
168
|
+
config: Optional dictionary containing additional configuration values.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
ComponentArgs representing the input args to a component function containing:
|
|
172
|
+
- positional_args: Dictionary of positional and positional-or-keyword arguments.
|
|
173
|
+
- var_args: List of variable positional arguments (*args).
|
|
174
|
+
- kwargs: Dictionary of keyword-only arguments.
|
|
175
|
+
|
|
176
|
+
Usage:
|
|
177
|
+
|
|
178
|
+
.. doctest::
|
|
179
|
+
from torchx.specs.api import AppDef
|
|
180
|
+
from torchx.specs.builders import component_args_from_str
|
|
181
|
+
|
|
182
|
+
def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
|
|
183
|
+
return AppDef(name="example")
|
|
184
|
+
|
|
185
|
+
# Supports space separated arguments
|
|
186
|
+
args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
|
|
187
|
+
parsed_args = component_args_from_str(example_component_fn, args)
|
|
188
|
+
|
|
189
|
+
assert parsed_args.positional_args == {"foo": "fooval"}
|
|
190
|
+
assert parsed_args.var_args == ["arg1", "arg2"]
|
|
191
|
+
assert parsed_args.kwargs == {"bar": "barval"}
|
|
192
|
+
|
|
193
|
+
# Supports '=' separated arguments
|
|
194
|
+
args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
|
|
195
|
+
parsed_args = component_args_from_str(example_component_fn, args)
|
|
196
|
+
|
|
197
|
+
assert parsed_args.positional_args == {"foo": "fooval"}
|
|
198
|
+
assert parsed_args.var_args == ["arg1", "arg2"]
|
|
199
|
+
assert parsed_args.kwargs == {"bar": "barval"}
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
"""
|
|
203
|
+
parsed_args: Namespace = parse_args(
|
|
204
|
+
cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
positional_args = {}
|
|
208
|
+
var_args = []
|
|
209
|
+
kwargs = {}
|
|
210
|
+
|
|
211
|
+
parameters = inspect.signature(cmpnt_fn).parameters
|
|
212
|
+
for param_name, parameter in parameters.items():
|
|
213
|
+
arg_value = getattr(parsed_args, param_name)
|
|
214
|
+
parameter_type = parameter.annotation
|
|
215
|
+
parameter_type = decode_optional(parameter_type)
|
|
216
|
+
if (
|
|
217
|
+
parameter_type != arg_value.__class__
|
|
218
|
+
and parameter.kind != inspect.Parameter.VAR_POSITIONAL
|
|
219
|
+
):
|
|
220
|
+
arg_value = decode(arg_value, parameter_type)
|
|
221
|
+
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
222
|
+
var_args = arg_value
|
|
223
|
+
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
224
|
+
kwargs[param_name] = arg_value
|
|
225
|
+
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
|
226
|
+
raise TypeError(
|
|
227
|
+
f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
|
|
228
|
+
f"type to a dict or explicitly declare the params"
|
|
229
|
+
)
|
|
230
|
+
else:
|
|
231
|
+
# POSITIONAL or POSITIONAL_OR_KEYWORD
|
|
232
|
+
positional_args[param_name] = arg_value
|
|
233
|
+
|
|
234
|
+
if len(var_args) > 0 and var_args[0] == "--":
|
|
235
|
+
var_args = var_args[1:]
|
|
236
|
+
|
|
237
|
+
return ComponentArgs(positional_args, var_args, kwargs)
|
|
238
|
+
|
|
239
|
+
|
|
143
240
|
def materialize_appdef(
|
|
144
|
-
cmpnt_fn: Callable[...,
|
|
241
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
145
242
|
cmpnt_args: List[str],
|
|
146
243
|
cmpnt_defaults: Optional[Dict[str, Any]] = None,
|
|
147
244
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -174,30 +271,14 @@ def materialize_appdef(
|
|
|
174
271
|
An application spec
|
|
175
272
|
"""
|
|
176
273
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
for param_name, parameter in parameters.items():
|
|
185
|
-
arg_value = getattr(parsed_args, param_name)
|
|
186
|
-
parameter_type = parameter.annotation
|
|
187
|
-
parameter_type = decode_optional(parameter_type)
|
|
188
|
-
arg_value = decode(arg_value, parameter_type)
|
|
189
|
-
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
190
|
-
var_arg = arg_value
|
|
191
|
-
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
192
|
-
kwargs[param_name] = arg_value
|
|
193
|
-
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
|
194
|
-
raise TypeError("**kwargs are not supported for component definitions")
|
|
195
|
-
else:
|
|
196
|
-
function_args.append(arg_value)
|
|
197
|
-
if len(var_arg) > 0 and var_arg[0] == "--":
|
|
198
|
-
var_arg = var_arg[1:]
|
|
274
|
+
component_args: ComponentArgs = component_args_from_str(
|
|
275
|
+
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
|
|
276
|
+
)
|
|
277
|
+
positional_arg_values = list(component_args.positional_args.values())
|
|
278
|
+
appdef = cmpnt_fn(
|
|
279
|
+
*positional_arg_values, *component_args.var_args, **component_args.kwargs
|
|
280
|
+
)
|
|
199
281
|
|
|
200
|
-
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
|
|
201
282
|
if not isinstance(appdef, AppDef):
|
|
202
283
|
raise TypeError(
|
|
203
284
|
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
|
torchx/specs/file_linter.py
CHANGED
|
@@ -11,8 +11,9 @@ import abc
|
|
|
11
11
|
import argparse
|
|
12
12
|
import ast
|
|
13
13
|
import inspect
|
|
14
|
+
import sys
|
|
14
15
|
from dataclasses import dataclass
|
|
15
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
16
17
|
|
|
17
18
|
from docstring_parser import parse
|
|
18
19
|
from torchx.util.io import read_conf_file
|
|
@@ -74,7 +75,7 @@ def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
|
|
|
74
75
|
if the description
|
|
75
76
|
"""
|
|
76
77
|
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
|
|
77
|
-
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
78
|
+
to your component (see: https://meta-pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
78
79
|
args_description = _get_default_arguments_descriptions(fn)
|
|
79
80
|
func_description = inspect.getdoc(fn)
|
|
80
81
|
if not func_description:
|
|
@@ -98,7 +99,7 @@ class LinterMessage:
|
|
|
98
99
|
severity: str = "error"
|
|
99
100
|
|
|
100
101
|
|
|
101
|
-
class
|
|
102
|
+
class ComponentFunctionValidator(abc.ABC):
|
|
102
103
|
@abc.abstractmethod
|
|
103
104
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
104
105
|
"""
|
|
@@ -116,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
|
|
|
116
117
|
)
|
|
117
118
|
|
|
118
119
|
|
|
119
|
-
|
|
120
|
+
def OK() -> list[LinterMessage]:
|
|
121
|
+
return [] # empty linter error means validation passes
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def is_primitive(arg: ast.expr) -> bool:
|
|
125
|
+
# whether the arg is a primitive type (e.g. int, float, str, bool)
|
|
126
|
+
return isinstance(arg, ast.Name)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def get_generic_type(arg: ast.expr) -> ast.expr:
|
|
130
|
+
# returns the slice expr of a subscripted type
|
|
131
|
+
# `arg` must be an instance of ast.Subscript (caller checks)
|
|
132
|
+
# in this validator's context, this is the generic type of a container type
|
|
133
|
+
# e.g. for Optional[str] returns the expr for str
|
|
134
|
+
|
|
135
|
+
assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
|
|
136
|
+
|
|
137
|
+
if isinstance(arg.slice, ast.Index): # python>=3.10
|
|
138
|
+
return arg.slice.value
|
|
139
|
+
else: # python-3.9
|
|
140
|
+
return arg.slice
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
|
|
144
|
+
"""
|
|
145
|
+
Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
|
|
146
|
+
is not an ``Optional``. Handles both:
|
|
147
|
+
1. ``typing.Optional[T]`` (python<3.10)
|
|
148
|
+
2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
|
|
149
|
+
"""
|
|
150
|
+
# case 1: 'a: Optional[T]'
|
|
151
|
+
if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
|
|
152
|
+
return get_generic_type(arg)
|
|
153
|
+
|
|
154
|
+
# case 2: 'a: T | None' or 'a: None | T'
|
|
155
|
+
if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
|
|
156
|
+
if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
|
|
157
|
+
if isinstance(arg.right, ast.Constant) and arg.right.value is None:
|
|
158
|
+
return arg.left
|
|
159
|
+
if isinstance(arg.left, ast.Constant) and arg.left.value is None:
|
|
160
|
+
return arg.right
|
|
161
|
+
|
|
162
|
+
# case 3: is not optional
|
|
163
|
+
return None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class ArgTypeValidator(ComponentFunctionValidator):
|
|
167
|
+
"""Validates component function's argument types."""
|
|
168
|
+
|
|
120
169
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
121
170
|
linter_errors = []
|
|
122
171
|
for arg_def in app_specs_func_def.args.args:
|
|
@@ -133,53 +182,68 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
|
|
|
133
182
|
return linter_errors
|
|
134
183
|
|
|
135
184
|
def _validate_arg_def(
|
|
136
|
-
self, function_name: str,
|
|
185
|
+
self, function_name: str, arg: ast.arg
|
|
137
186
|
) -> List[LinterMessage]:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
143
|
-
]
|
|
144
|
-
if isinstance(arg_def.annotation, ast.Name):
|
|
187
|
+
arg_type = arg.annotation # type hint
|
|
188
|
+
|
|
189
|
+
def ok() -> list[LinterMessage]:
|
|
190
|
+
# return value when validation passes (e.g. no linter errors)
|
|
145
191
|
return []
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
192
|
+
|
|
193
|
+
def err(reason: str) -> list[LinterMessage]:
|
|
194
|
+
msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
|
|
195
|
+
return [self._gen_linter_message(msg, arg.lineno)]
|
|
196
|
+
|
|
197
|
+
if not arg_type:
|
|
198
|
+
return err("Missing type annotation")
|
|
199
|
+
|
|
200
|
+
# Case 1: optional
|
|
201
|
+
if T := get_optional_type(arg_type):
|
|
202
|
+
# NOTE: optional types can be primitives or any of the allowed container types
|
|
203
|
+
# so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
|
|
204
|
+
arg_type = T
|
|
205
|
+
|
|
206
|
+
# Case 2: int, float, str, bool
|
|
207
|
+
if is_primitive(arg_type):
|
|
208
|
+
return ok()
|
|
209
|
+
# Case 3: Containers (Dict, List, Tuple)
|
|
210
|
+
elif isinstance(arg_type, ast.Subscript):
|
|
211
|
+
container_type = arg_type.value.id
|
|
212
|
+
|
|
213
|
+
if container_type in ["Dict", "dict"]:
|
|
214
|
+
KV = get_generic_type(arg_type)
|
|
215
|
+
|
|
216
|
+
assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
|
|
217
|
+
|
|
218
|
+
K, V = KV.elts
|
|
219
|
+
if not is_primitive(K):
|
|
220
|
+
return err(f"Non-primitive key type {ast.unparse(K)!r}")
|
|
221
|
+
if not is_primitive(V):
|
|
222
|
+
return err(f"Non-primitive value type {ast.unparse(V)!r}")
|
|
223
|
+
return ok()
|
|
224
|
+
elif container_type in ["List", "list"]:
|
|
225
|
+
T = get_generic_type(arg_type)
|
|
226
|
+
if is_primitive(T):
|
|
227
|
+
return ok()
|
|
228
|
+
else:
|
|
229
|
+
return err(f"Non-primitive element type {ast.unparse(T)!r}")
|
|
230
|
+
elif container_type in ["Tuple", "tuple"]:
|
|
231
|
+
E_N = get_generic_type(arg_type)
|
|
232
|
+
assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
|
|
233
|
+
|
|
234
|
+
for e in E_N.elts:
|
|
235
|
+
if not is_primitive(e):
|
|
236
|
+
return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
|
|
237
|
+
|
|
238
|
+
return ok()
|
|
239
|
+
|
|
240
|
+
return err(f"Unsupported container type {container_type!r}")
|
|
168
241
|
else:
|
|
169
|
-
|
|
170
|
-
if type_name == "Dict":
|
|
171
|
-
sub_type_tuple = cast(ast.Tuple, sub_type)
|
|
172
|
-
for el in sub_type_tuple.elts:
|
|
173
|
-
if not isinstance(el, ast.Name):
|
|
174
|
-
desc = "Dict can only have primitive types"
|
|
175
|
-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
|
|
176
|
-
elif not isinstance(sub_type, ast.Name):
|
|
177
|
-
desc = "List can only have primitive types"
|
|
178
|
-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
|
|
179
|
-
return linter_errors
|
|
242
|
+
return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
|
|
180
243
|
|
|
181
244
|
|
|
182
|
-
class
|
|
245
|
+
class ReturnTypeValidator(ComponentFunctionValidator):
|
|
246
|
+
"""Validates that component functions always return AppDef type"""
|
|
183
247
|
|
|
184
248
|
def __init__(self, supported_return_type: str) -> None:
|
|
185
249
|
super().__init__()
|
|
@@ -231,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
|
|
|
231
295
|
return linter_errors
|
|
232
296
|
|
|
233
297
|
|
|
234
|
-
class
|
|
298
|
+
class ComponentFnVisitor(ast.NodeVisitor):
|
|
235
299
|
"""
|
|
236
300
|
Visitor that finds the component_function and runs registered validators on it.
|
|
237
301
|
Current registered validators:
|
|
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
252
316
|
def __init__(
|
|
253
317
|
self,
|
|
254
318
|
component_function_name: str,
|
|
255
|
-
validators: Optional[List[
|
|
319
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
256
320
|
) -> None:
|
|
257
321
|
if validators is None:
|
|
258
|
-
self.validators: List[
|
|
259
|
-
|
|
260
|
-
|
|
322
|
+
self.validators: List[ComponentFunctionValidator] = [
|
|
323
|
+
ArgTypeValidator(),
|
|
324
|
+
ReturnTypeValidator("AppDef"),
|
|
261
325
|
]
|
|
262
326
|
else:
|
|
263
327
|
self.validators = validators
|
|
@@ -279,7 +343,7 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
279
343
|
def validate(
|
|
280
344
|
path: str,
|
|
281
345
|
component_function: str,
|
|
282
|
-
validators: Optional[List[
|
|
346
|
+
validators: Optional[List[ComponentFunctionValidator]] = None,
|
|
283
347
|
) -> List[LinterMessage]:
|
|
284
348
|
"""
|
|
285
349
|
Validates the function to make sure it complies the component standard.
|
|
@@ -309,7 +373,7 @@ def validate(
|
|
|
309
373
|
severity="error",
|
|
310
374
|
)
|
|
311
375
|
return [linter_message]
|
|
312
|
-
visitor =
|
|
376
|
+
visitor = ComponentFnVisitor(component_function, validators)
|
|
313
377
|
visitor.visit(module)
|
|
314
378
|
linter_errors = visitor.linter_errors
|
|
315
379
|
if not visitor.visited_function:
|
torchx/specs/finder.py
CHANGED
|
@@ -17,9 +17,15 @@ from dataclasses import dataclass
|
|
|
17
17
|
from inspect import getmembers, isfunction
|
|
18
18
|
from pathlib import Path
|
|
19
19
|
from types import ModuleType
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import Callable, Dict, Generator, List, Optional, Union
|
|
21
21
|
|
|
22
|
-
from torchx.specs
|
|
22
|
+
from torchx.specs import AppDef
|
|
23
|
+
|
|
24
|
+
from torchx.specs.file_linter import (
|
|
25
|
+
ComponentFunctionValidator,
|
|
26
|
+
get_fn_docstring,
|
|
27
|
+
validate,
|
|
28
|
+
)
|
|
23
29
|
from torchx.util import entrypoints
|
|
24
30
|
from torchx.util.io import read_conf_file
|
|
25
31
|
from torchx.util.types import none_throws
|
|
@@ -55,8 +61,7 @@ class _Component:
|
|
|
55
61
|
description: str
|
|
56
62
|
fn_name: str
|
|
57
63
|
|
|
58
|
-
|
|
59
|
-
fn: Callable[..., Any]
|
|
64
|
+
fn: Callable[..., AppDef]
|
|
60
65
|
|
|
61
66
|
validation_errors: List[str]
|
|
62
67
|
|
|
@@ -64,7 +69,7 @@ class _Component:
|
|
|
64
69
|
class ComponentsFinder(abc.ABC):
|
|
65
70
|
@abc.abstractmethod
|
|
66
71
|
def find(
|
|
67
|
-
self, validators: Optional[List[
|
|
72
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
68
73
|
) -> List[_Component]:
|
|
69
74
|
"""
|
|
70
75
|
Retrieves a set of components. A component is defined as a python
|
|
@@ -210,7 +215,7 @@ class ModuleComponentsFinder(ComponentsFinder):
|
|
|
210
215
|
yield self._try_import(module_info.name)
|
|
211
216
|
|
|
212
217
|
def find(
|
|
213
|
-
self, validators: Optional[List[
|
|
218
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
214
219
|
) -> List[_Component]:
|
|
215
220
|
components = []
|
|
216
221
|
for m in self._iter_modules_recursive(self.base_module):
|
|
@@ -230,7 +235,7 @@ class ModuleComponentsFinder(ComponentsFinder):
|
|
|
230
235
|
return module
|
|
231
236
|
|
|
232
237
|
def _get_components_from_module(
|
|
233
|
-
self, module: ModuleType, validators: Optional[List[
|
|
238
|
+
self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
|
|
234
239
|
) -> List[_Component]:
|
|
235
240
|
functions = getmembers(module, isfunction)
|
|
236
241
|
component_defs = []
|
|
@@ -269,28 +274,17 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
269
274
|
self,
|
|
270
275
|
path: str,
|
|
271
276
|
function_name: str,
|
|
272
|
-
validators: Optional[List[
|
|
277
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
273
278
|
) -> List[str]:
|
|
274
279
|
linter_errors = validate(path, function_name, validators)
|
|
275
280
|
return [linter_error.description for linter_error in linter_errors]
|
|
276
281
|
|
|
277
|
-
def _get_path_to_function_decl(
|
|
278
|
-
self, function: Callable[..., Any] # pyre-ignore[2]
|
|
279
|
-
) -> str:
|
|
280
|
-
"""
|
|
281
|
-
Attempts to return the path to the file where the function is implemented.
|
|
282
|
-
This can be different from the path where the function is looked up, for example if we have:
|
|
283
|
-
my_component defined in some_file.py, imported in other_file.py
|
|
284
|
-
and the component is invoked as other_file.py:my_component
|
|
285
|
-
"""
|
|
286
|
-
path_to_function_decl = inspect.getabsfile(function)
|
|
287
|
-
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
|
|
288
|
-
return self._filepath
|
|
289
|
-
return path_to_function_decl
|
|
290
|
-
|
|
291
282
|
def find(
|
|
292
|
-
self, validators: Optional[List[
|
|
283
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
293
284
|
) -> List[_Component]:
|
|
285
|
+
validation_errors = self._get_validation_errors(
|
|
286
|
+
self._filepath, self._function_name, validators
|
|
287
|
+
)
|
|
294
288
|
|
|
295
289
|
file_source = read_conf_file(self._filepath)
|
|
296
290
|
namespace = copy.copy(globals())
|
|
@@ -303,12 +297,6 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
303
297
|
)
|
|
304
298
|
app_fn = namespace[self._function_name]
|
|
305
299
|
fn_desc, _ = get_fn_docstring(app_fn)
|
|
306
|
-
|
|
307
|
-
func_path = self._get_path_to_function_decl(app_fn)
|
|
308
|
-
validation_errors = self._get_validation_errors(
|
|
309
|
-
func_path, self._function_name, validators
|
|
310
|
-
)
|
|
311
|
-
|
|
312
300
|
return [
|
|
313
301
|
_Component(
|
|
314
302
|
name=f"{self._filepath}:{self._function_name}",
|
|
@@ -321,7 +309,7 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
321
309
|
|
|
322
310
|
|
|
323
311
|
def _load_custom_components(
|
|
324
|
-
validators: Optional[List[
|
|
312
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
325
313
|
) -> List[_Component]:
|
|
326
314
|
component_modules = {
|
|
327
315
|
name: load_fn()
|
|
@@ -346,7 +334,7 @@ def _load_custom_components(
|
|
|
346
334
|
|
|
347
335
|
|
|
348
336
|
def _load_components(
|
|
349
|
-
validators: Optional[List[
|
|
337
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
350
338
|
) -> Dict[str, _Component]:
|
|
351
339
|
"""
|
|
352
340
|
Loads either the custom component defs from the entrypoint ``[torchx.components]``
|
|
@@ -368,7 +356,7 @@ _components: Optional[Dict[str, _Component]] = None
|
|
|
368
356
|
|
|
369
357
|
|
|
370
358
|
def _find_components(
|
|
371
|
-
validators: Optional[List[
|
|
359
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
372
360
|
) -> Dict[str, _Component]:
|
|
373
361
|
global _components
|
|
374
362
|
if not _components:
|
|
@@ -381,7 +369,7 @@ def _is_custom_component(component_name: str) -> bool:
|
|
|
381
369
|
|
|
382
370
|
|
|
383
371
|
def _find_custom_components(
|
|
384
|
-
name: str, validators: Optional[List[
|
|
372
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]]
|
|
385
373
|
) -> Dict[str, _Component]:
|
|
386
374
|
if ":" not in name:
|
|
387
375
|
raise ValueError(
|
|
@@ -393,7 +381,7 @@ def _find_custom_components(
|
|
|
393
381
|
|
|
394
382
|
|
|
395
383
|
def get_components(
|
|
396
|
-
validators: Optional[List[
|
|
384
|
+
validators: Optional[List[ComponentFunctionValidator]] = None,
|
|
397
385
|
) -> Dict[str, _Component]:
|
|
398
386
|
"""
|
|
399
387
|
Returns all custom components registered via ``[torchx.components]`` entrypoints
|
|
@@ -448,7 +436,7 @@ def get_components(
|
|
|
448
436
|
|
|
449
437
|
|
|
450
438
|
def get_component(
|
|
451
|
-
name: str, validators: Optional[List[
|
|
439
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
|
|
452
440
|
) -> _Component:
|
|
453
441
|
"""
|
|
454
442
|
Retrieves components by the provided name.
|
|
@@ -464,7 +452,7 @@ def get_component(
|
|
|
464
452
|
raise ComponentNotFoundException(
|
|
465
453
|
f"Component `{name}` not found. Please make sure it is one of the "
|
|
466
454
|
"builtins: `torchx builtins`. Or registered via `[torchx.components]` "
|
|
467
|
-
"entry point (see: https://pytorch.org/torchx/latest/configure.html)"
|
|
455
|
+
"entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
|
|
468
456
|
)
|
|
469
457
|
|
|
470
458
|
component = components[name]
|
|
@@ -477,7 +465,7 @@ def get_component(
|
|
|
477
465
|
|
|
478
466
|
|
|
479
467
|
def get_builtin_source(
|
|
480
|
-
name: str, validators: Optional[List[
|
|
468
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
|
|
481
469
|
) -> str:
|
|
482
470
|
"""
|
|
483
471
|
Returns a string of the the builtin component's function source code
|
|
torchx/specs/named_resources_aws.py
CHANGED
|
@@ -16,7 +16,7 @@ the equvalent resource in mem, cpu and gpu numbers.
|
|
|
16
16
|
|
|
17
17
|
.. note::
|
|
18
18
|
These resource definitions may change in future. It is expected for each user to
|
|
19
|
-
manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
|
|
19
|
+
manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
|
|
20
20
|
to set up named resources.
|
|
21
21
|
|
|
22
22
|
Usage:
|
|
@@ -47,7 +47,7 @@ NEURON_DEVICE = "aws.amazon.com/neurondevice"
|
|
|
47
47
|
MEM_TAX = 0.96
|
|
48
48
|
|
|
49
49
|
# determines instance type for non-honogeneous CEs
|
|
50
|
-
# see https://github.com/pytorch/torchx/issues/780
|
|
50
|
+
# see https://github.com/meta-pytorch/torchx/issues/780
|
|
51
51
|
K8S_ITYPE = "node.kubernetes.io/instance-type"
|
|
52
52
|
GiB: int = int(1024 * MEM_TAX)
|
|
53
53
|
|
|
@@ -120,6 +120,16 @@ def aws_p5_48xlarge() -> Resource:
|
|
|
120
120
|
)
|
|
121
121
|
|
|
122
122
|
|
|
123
|
+
def aws_p5e_48xlarge() -> Resource:
|
|
124
|
+
return Resource(
|
|
125
|
+
cpu=192,
|
|
126
|
+
gpu=8,
|
|
127
|
+
memMB=2048 * GiB,
|
|
128
|
+
capabilities={K8S_ITYPE: "p5e.48xlarge"},
|
|
129
|
+
devices={EFA_DEVICE: 32},
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
123
133
|
def aws_p5en_48xlarge() -> Resource:
|
|
124
134
|
return Resource(
|
|
125
135
|
cpu=192,
|
|
@@ -419,6 +429,7 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
|
|
|
419
429
|
"aws_p4d.24xlarge": aws_p4d_24xlarge,
|
|
420
430
|
"aws_p4de.24xlarge": aws_p4de_24xlarge,
|
|
421
431
|
"aws_p5.48xlarge": aws_p5_48xlarge,
|
|
432
|
+
"aws_p5e.48xlarge": aws_p5e_48xlarge,
|
|
422
433
|
"aws_p5en.48xlarge": aws_p5en_48xlarge,
|
|
423
434
|
"aws_g4dn.xlarge": aws_g4dn_xlarge,
|
|
424
435
|
"aws_g4dn.2xlarge": aws_g4dn_2xlarge,
|
torchx/tracker/__init__.py
CHANGED
|
@@ -32,7 +32,7 @@ implementation.
|
|
|
32
32
|
|
|
33
33
|
Example usage
|
|
34
34
|
-------------
|
|
35
|
-
Sample `code <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
|
|
35
|
+
Sample `code <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
Tracker Setup
|
|
@@ -111,7 +111,7 @@ Use :py:meth:`~torchx.tracker.app_run_from_env`:
|
|
|
111
111
|
Reference :py:class:`~torchx.tracker.api.TrackerBase` implementation
|
|
112
112
|
--------------------------------------------------------------------
|
|
113
113
|
:py:class:`~torchx.tracker.backend.fsspec.FsspecTracker` provides reference implementation of a tracker backend.
|
|
114
|
-
GitHub example `directory <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
|
|
114
|
+
GitHub example `directory <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
|
|
115
115
|
configure and use it in user application.
|
|
116
116
|
|
|
117
117
|
|
torchx/tracker/api.py
CHANGED
|
@@ -191,7 +191,7 @@ def build_trackers(
|
|
|
191
191
|
factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
|
|
192
192
|
if not factory:
|
|
193
193
|
logger.warning(
|
|
194
|
-
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
|
|
194
|
+
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
|
|
195
195
|
)
|
|
196
196
|
continue
|
|
197
197
|
if config:
|