torchx-nightly 2025.8.5__py3-none-any.whl → 2025.11.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +92 -81
- torchx/runner/config.py +3 -1
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +20 -15
- torchx/schedulers/aws_batch_scheduler.py +45 -2
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/local_scheduler.py +1 -0
- torchx/schedulers/slurm_scheduler.py +93 -24
- torchx/specs/__init__.py +23 -6
- torchx/specs/api.py +219 -11
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/specs/file_linter.py
CHANGED
|
@@ -11,8 +11,9 @@ import abc
|
|
|
11
11
|
import argparse
|
|
12
12
|
import ast
|
|
13
13
|
import inspect
|
|
14
|
+
import sys
|
|
14
15
|
from dataclasses import dataclass
|
|
15
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
16
17
|
|
|
17
18
|
from docstring_parser import parse
|
|
18
19
|
from torchx.util.io import read_conf_file
|
|
@@ -74,7 +75,7 @@ def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
|
|
|
74
75
|
if the description
|
|
75
76
|
"""
|
|
76
77
|
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
|
|
77
|
-
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
78
|
+
to your component (see: https://meta-pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
78
79
|
args_description = _get_default_arguments_descriptions(fn)
|
|
79
80
|
func_description = inspect.getdoc(fn)
|
|
80
81
|
if not func_description:
|
|
@@ -98,7 +99,7 @@ class LinterMessage:
|
|
|
98
99
|
severity: str = "error"
|
|
99
100
|
|
|
100
101
|
|
|
101
|
-
class
|
|
102
|
+
class ComponentFunctionValidator(abc.ABC):
|
|
102
103
|
@abc.abstractmethod
|
|
103
104
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
104
105
|
"""
|
|
@@ -116,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
|
|
|
116
117
|
)
|
|
117
118
|
|
|
118
119
|
|
|
119
|
-
|
|
120
|
+
def OK() -> list[LinterMessage]:
|
|
121
|
+
return [] # empty linter error means validation passes
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def is_primitive(arg: ast.expr) -> bool:
|
|
125
|
+
# whether the arg is a primitive type (e.g. int, float, str, bool)
|
|
126
|
+
return isinstance(arg, ast.Name)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def get_generic_type(arg: ast.expr) -> ast.expr:
|
|
130
|
+
# returns the slice expr of a subscripted type
|
|
131
|
+
# `arg` must be an instance of ast.Subscript (caller checks)
|
|
132
|
+
# in this validator's context, this is the generic type of a container type
|
|
133
|
+
# e.g. for Optional[str] returns the expr for str
|
|
134
|
+
|
|
135
|
+
assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
|
|
136
|
+
|
|
137
|
+
if isinstance(arg.slice, ast.Index): # python>=3.10
|
|
138
|
+
return arg.slice.value
|
|
139
|
+
else: # python-3.9
|
|
140
|
+
return arg.slice
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
|
|
144
|
+
"""
|
|
145
|
+
Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
|
|
146
|
+
is not an ``Optional``. Handles both:
|
|
147
|
+
1. ``typing.Optional[T]`` (python<3.10)
|
|
148
|
+
2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
|
|
149
|
+
"""
|
|
150
|
+
# case 1: 'a: Optional[T]'
|
|
151
|
+
if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
|
|
152
|
+
return get_generic_type(arg)
|
|
153
|
+
|
|
154
|
+
# case 2: 'a: T | None' or 'a: None | T'
|
|
155
|
+
if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
|
|
156
|
+
if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
|
|
157
|
+
if isinstance(arg.right, ast.Constant) and arg.right.value is None:
|
|
158
|
+
return arg.left
|
|
159
|
+
if isinstance(arg.left, ast.Constant) and arg.left.value is None:
|
|
160
|
+
return arg.right
|
|
161
|
+
|
|
162
|
+
# case 3: is not optional
|
|
163
|
+
return None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class ArgTypeValidator(ComponentFunctionValidator):
|
|
167
|
+
"""Validates component function's argument types."""
|
|
168
|
+
|
|
120
169
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
121
170
|
linter_errors = []
|
|
122
171
|
for arg_def in app_specs_func_def.args.args:
|
|
@@ -133,53 +182,68 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
|
|
|
133
182
|
return linter_errors
|
|
134
183
|
|
|
135
184
|
def _validate_arg_def(
|
|
136
|
-
self, function_name: str,
|
|
185
|
+
self, function_name: str, arg: ast.arg
|
|
137
186
|
) -> List[LinterMessage]:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
143
|
-
]
|
|
144
|
-
if isinstance(arg_def.annotation, ast.Name):
|
|
187
|
+
arg_type = arg.annotation # type hint
|
|
188
|
+
|
|
189
|
+
def ok() -> list[LinterMessage]:
|
|
190
|
+
# return value when validation passes (e.g. no linter errors)
|
|
145
191
|
return []
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
192
|
+
|
|
193
|
+
def err(reason: str) -> list[LinterMessage]:
|
|
194
|
+
msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
|
|
195
|
+
return [self._gen_linter_message(msg, arg.lineno)]
|
|
196
|
+
|
|
197
|
+
if not arg_type:
|
|
198
|
+
return err("Missing type annotation")
|
|
199
|
+
|
|
200
|
+
# Case 1: optional
|
|
201
|
+
if T := get_optional_type(arg_type):
|
|
202
|
+
# NOTE: optional types can be primitives or any of the allowed container types
|
|
203
|
+
# so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
|
|
204
|
+
arg_type = T
|
|
205
|
+
|
|
206
|
+
# Case 2: int, float, str, bool
|
|
207
|
+
if is_primitive(arg_type):
|
|
208
|
+
return ok()
|
|
209
|
+
# Case 3: Containers (Dict, List, Tuple)
|
|
210
|
+
elif isinstance(arg_type, ast.Subscript):
|
|
211
|
+
container_type = arg_type.value.id
|
|
212
|
+
|
|
213
|
+
if container_type in ["Dict", "dict"]:
|
|
214
|
+
KV = get_generic_type(arg_type)
|
|
215
|
+
|
|
216
|
+
assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
|
|
217
|
+
|
|
218
|
+
K, V = KV.elts
|
|
219
|
+
if not is_primitive(K):
|
|
220
|
+
return err(f"Non-primitive key type {ast.unparse(K)!r}")
|
|
221
|
+
if not is_primitive(V):
|
|
222
|
+
return err(f"Non-primitive value type {ast.unparse(V)!r}")
|
|
223
|
+
return ok()
|
|
224
|
+
elif container_type in ["List", "list"]:
|
|
225
|
+
T = get_generic_type(arg_type)
|
|
226
|
+
if is_primitive(T):
|
|
227
|
+
return ok()
|
|
228
|
+
else:
|
|
229
|
+
return err(f"Non-primitive element type {ast.unparse(T)!r}")
|
|
230
|
+
elif container_type in ["Tuple", "tuple"]:
|
|
231
|
+
E_N = get_generic_type(arg_type)
|
|
232
|
+
assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
|
|
233
|
+
|
|
234
|
+
for e in E_N.elts:
|
|
235
|
+
if not is_primitive(e):
|
|
236
|
+
return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
|
|
237
|
+
|
|
238
|
+
return ok()
|
|
239
|
+
|
|
240
|
+
return err(f"Unsupported container type {container_type!r}")
|
|
168
241
|
else:
|
|
169
|
-
|
|
170
|
-
if type_name == "Dict":
|
|
171
|
-
sub_type_tuple = cast(ast.Tuple, sub_type)
|
|
172
|
-
for el in sub_type_tuple.elts:
|
|
173
|
-
if not isinstance(el, ast.Name):
|
|
174
|
-
desc = "Dict can only have primitive types"
|
|
175
|
-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
|
|
176
|
-
elif not isinstance(sub_type, ast.Name):
|
|
177
|
-
desc = "List can only have primitive types"
|
|
178
|
-
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
|
|
179
|
-
return linter_errors
|
|
242
|
+
return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
|
|
180
243
|
|
|
181
244
|
|
|
182
|
-
class
|
|
245
|
+
class ReturnTypeValidator(ComponentFunctionValidator):
|
|
246
|
+
"""Validates that component functions always return AppDef type"""
|
|
183
247
|
|
|
184
248
|
def __init__(self, supported_return_type: str) -> None:
|
|
185
249
|
super().__init__()
|
|
@@ -231,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
|
|
|
231
295
|
return linter_errors
|
|
232
296
|
|
|
233
297
|
|
|
234
|
-
class
|
|
298
|
+
class ComponentFnVisitor(ast.NodeVisitor):
|
|
235
299
|
"""
|
|
236
300
|
Visitor that finds the component_function and runs registered validators on it.
|
|
237
301
|
Current registered validators:
|
|
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
252
316
|
def __init__(
|
|
253
317
|
self,
|
|
254
318
|
component_function_name: str,
|
|
255
|
-
validators: Optional[List[
|
|
319
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
256
320
|
) -> None:
|
|
257
321
|
if validators is None:
|
|
258
|
-
self.validators: List[
|
|
259
|
-
|
|
260
|
-
|
|
322
|
+
self.validators: List[ComponentFunctionValidator] = [
|
|
323
|
+
ArgTypeValidator(),
|
|
324
|
+
ReturnTypeValidator("AppDef"),
|
|
261
325
|
]
|
|
262
326
|
else:
|
|
263
327
|
self.validators = validators
|
|
@@ -279,7 +343,7 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
279
343
|
def validate(
|
|
280
344
|
path: str,
|
|
281
345
|
component_function: str,
|
|
282
|
-
validators: Optional[List[
|
|
346
|
+
validators: Optional[List[ComponentFunctionValidator]] = None,
|
|
283
347
|
) -> List[LinterMessage]:
|
|
284
348
|
"""
|
|
285
349
|
Validates the function to make sure it complies the component standard.
|
|
@@ -309,7 +373,7 @@ def validate(
|
|
|
309
373
|
severity="error",
|
|
310
374
|
)
|
|
311
375
|
return [linter_message]
|
|
312
|
-
visitor =
|
|
376
|
+
visitor = ComponentFnVisitor(component_function, validators)
|
|
313
377
|
visitor.visit(module)
|
|
314
378
|
linter_errors = visitor.linter_errors
|
|
315
379
|
if not visitor.visited_function:
|
torchx/specs/finder.py
CHANGED
|
@@ -17,9 +17,15 @@ from dataclasses import dataclass
|
|
|
17
17
|
from inspect import getmembers, isfunction
|
|
18
18
|
from pathlib import Path
|
|
19
19
|
from types import ModuleType
|
|
20
|
-
from typing import
|
|
20
|
+
from typing import Callable, Dict, Generator, List, Optional, Union
|
|
21
21
|
|
|
22
|
-
from torchx.specs
|
|
22
|
+
from torchx.specs import AppDef
|
|
23
|
+
|
|
24
|
+
from torchx.specs.file_linter import (
|
|
25
|
+
ComponentFunctionValidator,
|
|
26
|
+
get_fn_docstring,
|
|
27
|
+
validate,
|
|
28
|
+
)
|
|
23
29
|
from torchx.util import entrypoints
|
|
24
30
|
from torchx.util.io import read_conf_file
|
|
25
31
|
from torchx.util.types import none_throws
|
|
@@ -55,8 +61,7 @@ class _Component:
|
|
|
55
61
|
description: str
|
|
56
62
|
fn_name: str
|
|
57
63
|
|
|
58
|
-
|
|
59
|
-
fn: Callable[..., Any]
|
|
64
|
+
fn: Callable[..., AppDef]
|
|
60
65
|
|
|
61
66
|
validation_errors: List[str]
|
|
62
67
|
|
|
@@ -64,7 +69,7 @@ class _Component:
|
|
|
64
69
|
class ComponentsFinder(abc.ABC):
|
|
65
70
|
@abc.abstractmethod
|
|
66
71
|
def find(
|
|
67
|
-
self, validators: Optional[List[
|
|
72
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
68
73
|
) -> List[_Component]:
|
|
69
74
|
"""
|
|
70
75
|
Retrieves a set of components. A component is defined as a python
|
|
@@ -210,7 +215,7 @@ class ModuleComponentsFinder(ComponentsFinder):
|
|
|
210
215
|
yield self._try_import(module_info.name)
|
|
211
216
|
|
|
212
217
|
def find(
|
|
213
|
-
self, validators: Optional[List[
|
|
218
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
214
219
|
) -> List[_Component]:
|
|
215
220
|
components = []
|
|
216
221
|
for m in self._iter_modules_recursive(self.base_module):
|
|
@@ -230,7 +235,7 @@ class ModuleComponentsFinder(ComponentsFinder):
|
|
|
230
235
|
return module
|
|
231
236
|
|
|
232
237
|
def _get_components_from_module(
|
|
233
|
-
self, module: ModuleType, validators: Optional[List[
|
|
238
|
+
self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
|
|
234
239
|
) -> List[_Component]:
|
|
235
240
|
functions = getmembers(module, isfunction)
|
|
236
241
|
component_defs = []
|
|
@@ -269,28 +274,17 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
269
274
|
self,
|
|
270
275
|
path: str,
|
|
271
276
|
function_name: str,
|
|
272
|
-
validators: Optional[List[
|
|
277
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
273
278
|
) -> List[str]:
|
|
274
279
|
linter_errors = validate(path, function_name, validators)
|
|
275
280
|
return [linter_error.description for linter_error in linter_errors]
|
|
276
281
|
|
|
277
|
-
def _get_path_to_function_decl(
|
|
278
|
-
self, function: Callable[..., Any] # pyre-ignore[2]
|
|
279
|
-
) -> str:
|
|
280
|
-
"""
|
|
281
|
-
Attempts to return the path to the file where the function is implemented.
|
|
282
|
-
This can be different from the path where the function is looked up, for example if we have:
|
|
283
|
-
my_component defined in some_file.py, imported in other_file.py
|
|
284
|
-
and the component is invoked as other_file.py:my_component
|
|
285
|
-
"""
|
|
286
|
-
path_to_function_decl = inspect.getabsfile(function)
|
|
287
|
-
if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
|
|
288
|
-
return self._filepath
|
|
289
|
-
return path_to_function_decl
|
|
290
|
-
|
|
291
282
|
def find(
|
|
292
|
-
self, validators: Optional[List[
|
|
283
|
+
self, validators: Optional[List[ComponentFunctionValidator]]
|
|
293
284
|
) -> List[_Component]:
|
|
285
|
+
validation_errors = self._get_validation_errors(
|
|
286
|
+
self._filepath, self._function_name, validators
|
|
287
|
+
)
|
|
294
288
|
|
|
295
289
|
file_source = read_conf_file(self._filepath)
|
|
296
290
|
namespace = copy.copy(globals())
|
|
@@ -303,12 +297,6 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
303
297
|
)
|
|
304
298
|
app_fn = namespace[self._function_name]
|
|
305
299
|
fn_desc, _ = get_fn_docstring(app_fn)
|
|
306
|
-
|
|
307
|
-
func_path = self._get_path_to_function_decl(app_fn)
|
|
308
|
-
validation_errors = self._get_validation_errors(
|
|
309
|
-
func_path, self._function_name, validators
|
|
310
|
-
)
|
|
311
|
-
|
|
312
300
|
return [
|
|
313
301
|
_Component(
|
|
314
302
|
name=f"{self._filepath}:{self._function_name}",
|
|
@@ -321,7 +309,7 @@ class CustomComponentsFinder(ComponentsFinder):
|
|
|
321
309
|
|
|
322
310
|
|
|
323
311
|
def _load_custom_components(
|
|
324
|
-
validators: Optional[List[
|
|
312
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
325
313
|
) -> List[_Component]:
|
|
326
314
|
component_modules = {
|
|
327
315
|
name: load_fn()
|
|
@@ -346,7 +334,7 @@ def _load_custom_components(
|
|
|
346
334
|
|
|
347
335
|
|
|
348
336
|
def _load_components(
|
|
349
|
-
validators: Optional[List[
|
|
337
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
350
338
|
) -> Dict[str, _Component]:
|
|
351
339
|
"""
|
|
352
340
|
Loads either the custom component defs from the entrypoint ``[torchx.components]``
|
|
@@ -368,7 +356,7 @@ _components: Optional[Dict[str, _Component]] = None
|
|
|
368
356
|
|
|
369
357
|
|
|
370
358
|
def _find_components(
|
|
371
|
-
validators: Optional[List[
|
|
359
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
372
360
|
) -> Dict[str, _Component]:
|
|
373
361
|
global _components
|
|
374
362
|
if not _components:
|
|
@@ -381,7 +369,7 @@ def _is_custom_component(component_name: str) -> bool:
|
|
|
381
369
|
|
|
382
370
|
|
|
383
371
|
def _find_custom_components(
|
|
384
|
-
name: str, validators: Optional[List[
|
|
372
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]]
|
|
385
373
|
) -> Dict[str, _Component]:
|
|
386
374
|
if ":" not in name:
|
|
387
375
|
raise ValueError(
|
|
@@ -393,7 +381,7 @@ def _find_custom_components(
|
|
|
393
381
|
|
|
394
382
|
|
|
395
383
|
def get_components(
|
|
396
|
-
validators: Optional[List[
|
|
384
|
+
validators: Optional[List[ComponentFunctionValidator]] = None,
|
|
397
385
|
) -> Dict[str, _Component]:
|
|
398
386
|
"""
|
|
399
387
|
Returns all custom components registered via ``[torchx.components]`` entrypoints
|
|
@@ -448,7 +436,7 @@ def get_components(
|
|
|
448
436
|
|
|
449
437
|
|
|
450
438
|
def get_component(
|
|
451
|
-
name: str, validators: Optional[List[
|
|
439
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
|
|
452
440
|
) -> _Component:
|
|
453
441
|
"""
|
|
454
442
|
Retrieves components by the provided name.
|
|
@@ -464,7 +452,7 @@ def get_component(
|
|
|
464
452
|
raise ComponentNotFoundException(
|
|
465
453
|
f"Component `{name}` not found. Please make sure it is one of the "
|
|
466
454
|
"builtins: `torchx builtins`. Or registered via `[torchx.components]` "
|
|
467
|
-
"entry point (see: https://pytorch.org/torchx/latest/configure.html)"
|
|
455
|
+
"entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
|
|
468
456
|
)
|
|
469
457
|
|
|
470
458
|
component = components[name]
|
|
@@ -477,7 +465,7 @@ def get_component(
|
|
|
477
465
|
|
|
478
466
|
|
|
479
467
|
def get_builtin_source(
|
|
480
|
-
name: str, validators: Optional[List[
|
|
468
|
+
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
|
|
481
469
|
) -> str:
|
|
482
470
|
"""
|
|
483
471
|
Returns a string of the the builtin component's function source code
|
|
@@ -16,7 +16,7 @@ the equvalent resource in mem, cpu and gpu numbers.
|
|
|
16
16
|
|
|
17
17
|
.. note::
|
|
18
18
|
These resource definitions may change in future. It is expected for each user to
|
|
19
|
-
manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
|
|
19
|
+
manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
|
|
20
20
|
to set up named resources.
|
|
21
21
|
|
|
22
22
|
Usage:
|
|
@@ -47,7 +47,7 @@ NEURON_DEVICE = "aws.amazon.com/neurondevice"
|
|
|
47
47
|
MEM_TAX = 0.96
|
|
48
48
|
|
|
49
49
|
# determines instance type for non-honogeneous CEs
|
|
50
|
-
# see https://github.com/pytorch/torchx/issues/780
|
|
50
|
+
# see https://github.com/meta-pytorch/torchx/issues/780
|
|
51
51
|
K8S_ITYPE = "node.kubernetes.io/instance-type"
|
|
52
52
|
GiB: int = int(1024 * MEM_TAX)
|
|
53
53
|
|
|
@@ -120,6 +120,16 @@ def aws_p5_48xlarge() -> Resource:
|
|
|
120
120
|
)
|
|
121
121
|
|
|
122
122
|
|
|
123
|
+
def aws_p5e_48xlarge() -> Resource:
|
|
124
|
+
return Resource(
|
|
125
|
+
cpu=192,
|
|
126
|
+
gpu=8,
|
|
127
|
+
memMB=2048 * GiB,
|
|
128
|
+
capabilities={K8S_ITYPE: "p5e.48xlarge"},
|
|
129
|
+
devices={EFA_DEVICE: 32},
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
123
133
|
def aws_p5en_48xlarge() -> Resource:
|
|
124
134
|
return Resource(
|
|
125
135
|
cpu=192,
|
|
@@ -419,6 +429,7 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
|
|
|
419
429
|
"aws_p4d.24xlarge": aws_p4d_24xlarge,
|
|
420
430
|
"aws_p4de.24xlarge": aws_p4de_24xlarge,
|
|
421
431
|
"aws_p5.48xlarge": aws_p5_48xlarge,
|
|
432
|
+
"aws_p5e.48xlarge": aws_p5e_48xlarge,
|
|
422
433
|
"aws_p5en.48xlarge": aws_p5en_48xlarge,
|
|
423
434
|
"aws_g4dn.xlarge": aws_g4dn_xlarge,
|
|
424
435
|
"aws_g4dn.2xlarge": aws_g4dn_2xlarge,
|
torchx/tracker/__init__.py
CHANGED
|
@@ -32,7 +32,7 @@ implementation.
|
|
|
32
32
|
|
|
33
33
|
Example usage
|
|
34
34
|
-------------
|
|
35
|
-
Sample `code <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
|
|
35
|
+
Sample `code <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
Tracker Setup
|
|
@@ -111,7 +111,7 @@ Use :py:meth:`~torchx.tracker.app_run_from_env`:
|
|
|
111
111
|
Reference :py:class:`~torchx.tracker.api.TrackerBase` implementation
|
|
112
112
|
--------------------------------------------------------------------
|
|
113
113
|
:py:class:`~torchx.tracker.backend.fsspec.FsspecTracker` provides reference implementation of a tracker backend.
|
|
114
|
-
GitHub example `directory <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
|
|
114
|
+
GitHub example `directory <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
|
|
115
115
|
configure and use it in user application.
|
|
116
116
|
|
|
117
117
|
|
torchx/tracker/api.py
CHANGED
|
@@ -191,7 +191,7 @@ def build_trackers(
|
|
|
191
191
|
factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
|
|
192
192
|
if not factory:
|
|
193
193
|
logger.warning(
|
|
194
|
-
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
|
|
194
|
+
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
|
|
195
195
|
)
|
|
196
196
|
continue
|
|
197
197
|
if config:
|
torchx/util/entrypoints.py
CHANGED
|
@@ -69,9 +69,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
|
|
|
69
69
|
return run
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
def load_group(
|
|
73
|
-
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
|
|
74
|
-
):
|
|
72
|
+
def load_group(group: str, default: Optional[Dict[str, Any]] = None):
|
|
75
73
|
"""
|
|
76
74
|
Loads all the entry points specified by ``group`` and returns
|
|
77
75
|
the entry points as a map of ``name (str) -> deferred_load_fn``.
|
|
@@ -90,7 +88,6 @@ def load_group(
|
|
|
90
88
|
1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
|
|
91
89
|
1. ``load_group("food")`` -> ``None``
|
|
92
90
|
1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
|
|
93
|
-
1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
|
|
94
91
|
|
|
95
92
|
|
|
96
93
|
If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
|
|
@@ -115,8 +112,6 @@ def load_group(
|
|
|
115
112
|
entrypoints = metadata.entry_points().get(group, ())
|
|
116
113
|
|
|
117
114
|
if len(entrypoints) == 0:
|
|
118
|
-
if skip_defaults:
|
|
119
|
-
return None
|
|
120
115
|
return default
|
|
121
116
|
|
|
122
117
|
eps = {}
|
torchx/util/strings.py
CHANGED
|
@@ -13,7 +13,7 @@ def normalize_str(data: str) -> str:
|
|
|
13
13
|
"""
|
|
14
14
|
Invokes ``lower`` on thes string and removes all
|
|
15
15
|
characters that do not satisfy ``[a-z0-9\\-]`` pattern.
|
|
16
|
-
This method is mostly used to make sure kubernetes
|
|
16
|
+
This method is mostly used to make sure kubernetes scheduler gets
|
|
17
17
|
the job name that does not violate its restrictions.
|
|
18
18
|
"""
|
|
19
19
|
if data.startswith("-"):
|
torchx/util/types.py
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
import inspect
|
|
10
10
|
import re
|
|
11
|
+
from types import UnionType
|
|
11
12
|
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
|
|
12
13
|
|
|
13
14
|
|
|
@@ -234,10 +235,20 @@ def decode_optional(param_type: Any) -> Any:
|
|
|
234
235
|
If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
|
|
235
236
|
Otherwise returns ``param_type``
|
|
236
237
|
"""
|
|
238
|
+
|
|
237
239
|
if not hasattr(param_type, "__origin__"):
|
|
238
|
-
|
|
240
|
+
if isinstance(param_type, UnionType):
|
|
241
|
+
# handle BinOp style Optional (e.g. `T | None`)
|
|
242
|
+
if len(param_type.__args__) == 2 and param_type.__args__[1] is type(None):
|
|
243
|
+
return param_type.__args__[0]
|
|
244
|
+
else:
|
|
245
|
+
return param_type
|
|
246
|
+
else:
|
|
247
|
+
return param_type
|
|
248
|
+
|
|
239
249
|
if param_type.__origin__ is not Union:
|
|
240
250
|
return param_type
|
|
251
|
+
|
|
241
252
|
args = param_type.__args__
|
|
242
253
|
if len(args) == 2 and args[1] is type(None):
|
|
243
254
|
return args[0]
|
torchx/version.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
1
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
2
|
# All rights reserved.
|
|
4
3
|
#
|
|
@@ -7,6 +6,7 @@
|
|
|
7
6
|
|
|
8
7
|
# pyre-strict
|
|
9
8
|
|
|
9
|
+
from torchx._version import BASE_VERSION
|
|
10
10
|
from torchx.util.entrypoints import load
|
|
11
11
|
|
|
12
12
|
# Follows PEP-0440 version scheme guidelines
|
|
@@ -18,7 +18,7 @@ from torchx.util.entrypoints import load
|
|
|
18
18
|
# 0.1.0bN # Beta release
|
|
19
19
|
# 0.1.0rcN # Release Candidate
|
|
20
20
|
# 0.1.0 # Final release
|
|
21
|
-
__version__ =
|
|
21
|
+
__version__: str = BASE_VERSION
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
# Use the github container registry images corresponding to the current package
|