torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchx-nightly might be problematic. Click here for more details.
- torchx/__init__.py +2 -0
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +8 -4
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +269 -45
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +12 -1
- torchx/cli/cmd_tracker.py +3 -1
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +4 -0
- torchx/components/__init__.py +3 -8
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +18 -7
- torchx/components/integration_tests/component_provider.py +4 -2
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +7 -6
- torchx/components/utils.py +15 -4
- torchx/distributed/__init__.py +2 -4
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +7 -6
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +11 -2
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +167 -60
- torchx/runner/config.py +43 -10
- torchx/runner/events/__init__.py +57 -13
- torchx/runner/events/api.py +14 -3
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +16 -15
- torchx/schedulers/api.py +70 -14
- torchx/schedulers/aws_batch_scheduler.py +79 -5
- torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +43 -11
- torchx/schedulers/ids.py +29 -23
- torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
- torchx/schedulers/kubernetes_scheduler.py +383 -38
- torchx/schedulers/local_scheduler.py +100 -27
- torchx/schedulers/lsf_scheduler.py +5 -4
- torchx/schedulers/slurm_scheduler.py +336 -20
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +89 -12
- torchx/specs/api.py +431 -32
- torchx/specs/builders.py +176 -38
- torchx/specs/file_linter.py +143 -57
- torchx/specs/finder.py +68 -28
- torchx/specs/named_resources_aws.py +254 -22
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/overlays.py +106 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +12 -6
- torchx/tracker/api.py +15 -18
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +39 -15
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +65 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +3 -1
- torchx/util/types.py +90 -29
- torchx/version.py +4 -2
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +136 -6
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +30 -2
- torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
- torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
- torchx/pipelines/kfp/__init__.py +0 -28
- torchx/pipelines/kfp/adapter.py +0 -271
- torchx/pipelines/kfp/version.py +0 -17
- torchx/schedulers/gcp_batch_scheduler.py +0 -487
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -453
- torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
- torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/builders.py
CHANGED
|
@@ -4,27 +4,46 @@
|
|
|
4
4
|
# This source code is licensed under the BSD-style license found in the
|
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
|
6
6
|
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
7
9
|
import argparse
|
|
8
10
|
import inspect
|
|
9
|
-
|
|
11
|
+
import os
|
|
12
|
+
from argparse import Namespace
|
|
13
|
+
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
|
|
10
14
|
|
|
11
15
|
from torchx.specs.api import BindMount, MountType, VolumeMount
|
|
12
16
|
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
|
|
13
|
-
from torchx.util.types import
|
|
14
|
-
decode_from_string,
|
|
15
|
-
decode_optional,
|
|
16
|
-
get_argparse_param_type,
|
|
17
|
-
is_bool,
|
|
18
|
-
is_primitive,
|
|
19
|
-
)
|
|
17
|
+
from torchx.util.types import decode, decode_optional, get_argparse_param_type, is_bool
|
|
20
18
|
|
|
21
19
|
from .api import AppDef, DeviceMount
|
|
22
20
|
|
|
23
21
|
|
|
22
|
+
class ComponentArgs(NamedTuple):
    """Component function arguments after parsing and type-decoding.

    Fields:
        positional_args: values for positional / positional-or-keyword
            parameters, keyed by parameter name
        var_args: values collected for the ``*args`` parameter
        kwargs: values for keyword-only parameters, keyed by parameter name
    """

    positional_args: Dict[str, Any]
    var_args: List[str]
    kwargs: Dict[str, Any]
|
|
28
|
+
|
|
29
|
+
|
|
24
30
|
def _create_args_parser(
    cmpnt_fn: Callable[..., AppDef],
    cmpnt_defaults: Optional[Dict[str, str]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> argparse.ArgumentParser:
    """Build an argparse parser for ``cmpnt_fn`` from its inspected signature.

    Thin wrapper that extracts the signature parameters of ``cmpnt_fn`` and
    delegates to ``_create_args_parser_from_parameters``.
    """
    signature = inspect.signature(cmpnt_fn)
    return _create_args_parser_from_parameters(
        cmpnt_fn, signature.parameters, cmpnt_defaults, config
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _create_args_parser_from_parameters(
|
|
42
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
43
|
+
parameters: Mapping[str, inspect.Parameter],
|
|
44
|
+
cmpnt_defaults: Optional[Dict[str, str]] = None,
|
|
45
|
+
config: Optional[Dict[str, Any]] = None,
|
|
46
|
+
) -> argparse.ArgumentParser:
|
|
28
47
|
function_desc, args_desc = get_fn_docstring(cmpnt_fn)
|
|
29
48
|
script_parser = argparse.ArgumentParser(
|
|
30
49
|
prog=f"torchx run <run args...> {cmpnt_fn.__name__} ",
|
|
@@ -85,15 +104,144 @@ def _create_args_parser(
|
|
|
85
104
|
if len(param_name) == 1:
|
|
86
105
|
arg_names = [f"-{param_name}"] + arg_names
|
|
87
106
|
if "default" not in args:
|
|
88
|
-
|
|
107
|
+
if (config and param_name not in config) or not config:
|
|
108
|
+
args["required"] = True
|
|
109
|
+
|
|
89
110
|
script_parser.add_argument(*arg_names, **args)
|
|
90
111
|
return script_parser
|
|
91
112
|
|
|
92
113
|
|
|
114
|
+
def _merge_config_values_with_args(
|
|
115
|
+
parsed_args: argparse.Namespace, config: Dict[str, Any]
|
|
116
|
+
) -> None:
|
|
117
|
+
for key, val in config.items():
|
|
118
|
+
if key in parsed_args:
|
|
119
|
+
setattr(parsed_args, key, val)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def parse_args(
    cmpnt_fn: Callable[..., AppDef],
    cmpnt_args: List[str],
    cmpnt_defaults: Optional[Dict[str, Any]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Namespace:
    """
    Build a namespace holding the arguments for a component function.

    ``cmpnt_args`` is parsed with a parser derived from ``cmpnt_fn`` (and
    ``cmpnt_defaults`` / ``config``). If ``config`` is non-empty, every config
    entry whose key is already present on the parsed namespace then
    overwrites the parsed value.

    Args:
        cmpnt_fn: Component function
        cmpnt_args: Function args
        cmpnt_defaults: Additional default values for parameters of ``cmpnt_fn``
            (overrides the defaults set on the fn declaration)
        config: Optional dict containing additional configuration for the component from a passed config file

    Returns:
        A Namespace object with the args, defaults, and config values incorporated.
    """
    parser = _create_args_parser(cmpnt_fn, cmpnt_defaults, config)
    namespace = parser.parse_args(cmpnt_args)
    if not config:
        return namespace

    _merge_config_values_with_args(namespace, config)
    return namespace
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def component_args_from_str(
    cmpnt_fn: Callable[..., AppDef],
    cmpnt_args: list[str],
    cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> ComponentArgs:
    """
    Parse ``cmpnt_args`` against ``cmpnt_fn``'s signature and decode the values.

    The arguments are first parsed with :py:func:`parse_args`, which also
    applies ``cmpnt_args_defaults`` and ``config``. Then every value whose class
    differs from its parameter's annotation (with ``Optional`` unwrapped) is
    converted with ``decode``. The ``*args`` values are never decoded. Finally,
    the values are grouped by parameter kind.

    Args:
        cmpnt_fn: The component function whose arguments are to be parsed and decoded.
        cmpnt_args: Command-line style arguments; both ``--name value`` and
            ``--name=value`` forms are supported.
        cmpnt_args_defaults: Optional default values for the component function's parameters.
        config: Optional dictionary containing additional configuration values.

    Returns:
        ComponentArgs containing:

        - positional_args: positional and positional-or-keyword arguments by name
        - var_args: values for the ``*args`` parameter (a leading ``"--"`` is dropped)
        - kwargs: keyword-only arguments by name

    Raises:
        TypeError: if ``cmpnt_fn`` declares a ``**kwargs`` parameter.

    Usage:

    .. code-block:: python

        from torchx.specs.api import AppDef
        from torchx.specs.builders import component_args_from_str

        def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
            return AppDef(name="example")

        # space separated arguments
        args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
        parsed_args = component_args_from_str(example_component_fn, args)
        assert parsed_args.positional_args == {"foo": "fooval"}
        assert parsed_args.var_args == ["arg1", "arg2"]
        assert parsed_args.kwargs == {"bar": "barval"}

        # '=' separated arguments
        args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
        parsed_args = component_args_from_str(example_component_fn, args)
        assert parsed_args.positional_args == {"foo": "fooval"}
        assert parsed_args.var_args == ["arg1", "arg2"]
        assert parsed_args.kwargs == {"bar": "barval"}
    """
    namespace: Namespace = parse_args(
        cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
    )

    positional = {}
    variadic = []
    keyword_only = {}

    for name, param in inspect.signature(cmpnt_fn).parameters.items():
        value = getattr(namespace, name)
        expected_type = decode_optional(param.annotation)
        is_var_positional = param.kind == inspect.Parameter.VAR_POSITIONAL

        # decode only when the parsed value is not already of the annotated type
        if expected_type != value.__class__ and not is_var_positional:
            value = decode(value, expected_type)

        if is_var_positional:
            variadic = value
        elif param.kind == inspect.Parameter.KEYWORD_ONLY:
            keyword_only[name] = value
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError(
                f"component fn param `{name}` is a '**kwargs' which is not supported; consider changing the "
                f"type to a dict or explicitly declare the params"
            )
        else:
            # POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD
            positional[name] = value

    # a leading "--" separator is not a real *args value
    if len(variadic) > 0 and variadic[0] == "--":
        variadic = variadic[1:]

    return ComponentArgs(positional, variadic, keyword_only)
|
|
238
|
+
|
|
239
|
+
|
|
93
240
|
def materialize_appdef(
|
|
94
241
|
cmpnt_fn: Callable[..., AppDef],
|
|
95
242
|
cmpnt_args: List[str],
|
|
96
|
-
cmpnt_defaults: Optional[Dict[str,
|
|
243
|
+
cmpnt_defaults: Optional[Dict[str, Any]] = None,
|
|
244
|
+
config: Optional[Dict[str, Any]] = None,
|
|
97
245
|
) -> AppDef:
|
|
98
246
|
"""
|
|
99
247
|
Creates an application by running user defined ``app_fn``.
|
|
@@ -118,38 +266,25 @@ def materialize_appdef(
|
|
|
118
266
|
cmpnt_args: Function args
|
|
119
267
|
cmpnt_defaults: Additional default values for parameters of ``app_fn``
|
|
120
268
|
(overrides the defaults set on the fn declaration)
|
|
269
|
+
config: Optional dict containing additional configuration for the component from a passed config file
|
|
121
270
|
Returns:
|
|
122
271
|
An application spec
|
|
123
272
|
"""
|
|
124
273
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
274
|
+
component_args: ComponentArgs = component_args_from_str(
|
|
275
|
+
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
|
|
276
|
+
)
|
|
277
|
+
positional_arg_values = list(component_args.positional_args.values())
|
|
278
|
+
appdef = cmpnt_fn(
|
|
279
|
+
*positional_arg_values, *component_args.var_args, **component_args.kwargs
|
|
280
|
+
)
|
|
131
281
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
parameter_type = decode_optional(parameter_type)
|
|
137
|
-
if is_bool(parameter_type):
|
|
138
|
-
arg_value = arg_value and arg_value.lower() == "true"
|
|
139
|
-
elif not is_primitive(parameter_type):
|
|
140
|
-
arg_value = decode_from_string(arg_value, parameter_type)
|
|
141
|
-
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
142
|
-
var_arg = arg_value
|
|
143
|
-
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
144
|
-
kwargs[param_name] = arg_value
|
|
145
|
-
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
|
146
|
-
raise TypeError("**kwargs are not supported for component definitions")
|
|
147
|
-
else:
|
|
148
|
-
function_args.append(arg_value)
|
|
149
|
-
if len(var_arg) > 0 and var_arg[0] == "--":
|
|
150
|
-
var_arg = var_arg[1:]
|
|
282
|
+
if not isinstance(appdef, AppDef):
|
|
283
|
+
raise TypeError(
|
|
284
|
+
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
|
|
285
|
+
)
|
|
151
286
|
|
|
152
|
-
return
|
|
287
|
+
return appdef
|
|
153
288
|
|
|
154
289
|
|
|
155
290
|
def make_app_handle(scheduler_backend: str, session_name: str, app_id: str) -> str:
|
|
@@ -205,9 +340,12 @@ def parse_mounts(opts: List[str]) -> List[Union[BindMount, VolumeMount, DeviceMo
|
|
|
205
340
|
for opts in mount_opts:
|
|
206
341
|
typ = opts.get("type")
|
|
207
342
|
if typ == MountType.BIND:
|
|
343
|
+
src_path = opts["src"]
|
|
344
|
+
if src_path.startswith("~"):
|
|
345
|
+
src_path = os.path.expanduser(src_path)
|
|
208
346
|
mounts.append(
|
|
209
347
|
BindMount(
|
|
210
|
-
src_path=
|
|
348
|
+
src_path=src_path,
|
|
211
349
|
dst_path=opts["dst"],
|
|
212
350
|
read_only="readonly" in opts,
|
|
213
351
|
)
|
torchx/specs/file_linter.py
CHANGED
|
@@ -5,12 +5,15 @@
|
|
|
5
5
|
# This source code is licensed under the BSD-style license found in the
|
|
6
6
|
# LICENSE file in the root directory of this source tree.
|
|
7
7
|
|
|
8
|
+
# pyre-strict
|
|
9
|
+
|
|
8
10
|
import abc
|
|
9
11
|
import argparse
|
|
10
12
|
import ast
|
|
11
13
|
import inspect
|
|
14
|
+
import sys
|
|
12
15
|
from dataclasses import dataclass
|
|
13
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
14
17
|
|
|
15
18
|
from docstring_parser import parse
|
|
16
19
|
from torchx.util.io import read_conf_file
|
|
@@ -29,7 +32,11 @@ def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str,
|
|
|
29
32
|
return args_decs
|
|
30
33
|
|
|
31
34
|
|
|
32
|
-
class TorchXArgumentHelpFormatter(
|
|
35
|
+
class TorchXArgumentHelpFormatter(
|
|
36
|
+
argparse.RawDescriptionHelpFormatter,
|
|
37
|
+
argparse.ArgumentDefaultsHelpFormatter,
|
|
38
|
+
argparse.MetavarTypeHelpFormatter,
|
|
39
|
+
):
|
|
33
40
|
"""Help message formatter which adds default values and required to argument help.
|
|
34
41
|
|
|
35
42
|
If the argument is required, the class appends `(required)` at the end of the help message.
|
|
@@ -68,7 +75,7 @@ def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
|
|
|
68
75
|
if the description
|
|
69
76
|
"""
|
|
70
77
|
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
|
|
71
|
-
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
78
|
+
to your component (see: https://meta-pytorch.org/torchx/latest/component_best_practices.html)"""
|
|
72
79
|
args_description = _get_default_arguments_descriptions(fn)
|
|
73
80
|
func_description = inspect.getdoc(fn)
|
|
74
81
|
if not func_description:
|
|
@@ -79,7 +86,7 @@ to your component (see: https://pytorch.org/torchx/latest/component_best_practic
|
|
|
79
86
|
args_description[param.arg_name] = param.description
|
|
80
87
|
short_func_description = docstring.short_description or default_fn_desc
|
|
81
88
|
if docstring.long_description:
|
|
82
|
-
short_func_description += "
|
|
89
|
+
short_func_description += "\n" + docstring.long_description
|
|
83
90
|
return (short_func_description or default_fn_desc, args_description)
|
|
84
91
|
|
|
85
92
|
|
|
@@ -92,7 +99,7 @@ class LinterMessage:
|
|
|
92
99
|
severity: str = "error"
|
|
93
100
|
|
|
94
101
|
|
|
95
|
-
class
|
|
102
|
+
class ComponentFunctionValidator(abc.ABC):
|
|
96
103
|
@abc.abstractmethod
|
|
97
104
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
98
105
|
"""
|
|
@@ -110,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
|
|
|
110
117
|
)
|
|
111
118
|
|
|
112
119
|
|
|
113
|
-
|
|
120
|
+
def OK() -> list[LinterMessage]:
    """Return the result of a passing validation: an empty list of linter messages."""
    no_errors: list[LinterMessage] = []
    return no_errors
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def is_primitive(arg: ast.expr) -> bool:
    """Return whether the annotation ``arg`` is a bare name (an ``ast.Name``).

    This is how the linter identifies primitive types such as ``int``,
    ``float``, ``str`` and ``bool``. The check is purely syntactic: any bare
    identifier is accepted, not only those builtins. Subscripts
    (``List[int]``) and dotted names (``typing.Any``) are not.
    """
    bare_name = isinstance(arg, ast.Name)
    return bare_name
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def get_generic_type(arg: ast.expr) -> ast.expr:
    """Return the slice (type-parameter) expression of a subscripted type.

    In this validator's context, this is the generic type of a container type.
    For example, ``Optional[str]`` yields the expr for ``str``, and
    ``Dict[str, int]`` yields an ``ast.Tuple`` holding ``str`` and ``int``.

    ``arg`` must be an ``ast.Subscript``; callers check this beforehand.
    """
    assert isinstance(arg, ast.Subscript)  # e.g. arg = C[T]

    # On python<3.9 the slice is wrapped in an ``ast.Index`` node; from
    # python-3.9 on, ``arg.slice`` is the expression itself. ``ast.Index`` is
    # deprecated, so look it up defensively in case it is ever removed.
    index_node_type = getattr(ast, "Index", None)
    if index_node_type is not None and isinstance(arg.slice, index_node_type):
        return arg.slice.value  # python<3.9
    return arg.slice  # python>=3.9
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
    """
    Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if ``arg``
    is not an ``Optional``. Handles both:

    1. ``Optional[T]`` / ``typing.Optional[T]`` (any python version)
    2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
    """
    # case 1: 'a: Optional[T]' or 'a: typing.Optional[T]'
    if isinstance(arg, ast.Subscript):
        container = arg.value
        # The subscripted value is not necessarily an ast.Name: it may be a
        # dotted attribute (``typing.List``) or any other expression. Check the
        # node type instead of reading ``.id`` unconditionally, which raised
        # AttributeError for non-Name values.
        is_optional = (
            isinstance(container, ast.Name) and container.id == "Optional"
        ) or (isinstance(container, ast.Attribute) and container.attr == "Optional")
        if is_optional:
            return get_generic_type(arg)

    # case 2: 'a: T | None' or 'a: None | T'
    if sys.version_info >= (3, 10):  # PEP 604 introduced in python-3.10
        if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
            if isinstance(arg.right, ast.Constant) and arg.right.value is None:
                return arg.left
            if isinstance(arg.left, ast.Constant) and arg.left.value is None:
                return arg.right

    # case 3: is not optional
    return None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class ArgTypeValidator(ComponentFunctionValidator):
|
|
167
|
+
"""Validates component function's argument types."""
|
|
168
|
+
|
|
114
169
|
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
|
|
115
170
|
linter_errors = []
|
|
116
171
|
for arg_def in app_specs_func_def.args.args:
|
|
@@ -127,53 +182,73 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
|
|
|
127
182
|
return linter_errors
|
|
128
183
|
|
|
129
184
|
def _validate_arg_def(
    self, function_name: str, arg: ast.arg
) -> List[LinterMessage]:
    """Validate the type annotation of a single component function argument.

    Accepted annotations, optionally wrapped in ``Optional[...]`` or ``T | None``:

    * bare names such as ``int`` or ``str`` (see ``is_primitive``)
    * ``Dict[K, V]`` / ``dict[K, V]`` with bare-name ``K`` and ``V``
    * ``List[T]`` / ``list[T]`` with a bare-name ``T``
    * ``Tuple[T1, ...]`` / ``tuple[T1, ...]`` with bare-name element types

    The container name may also be written dotted (e.g. ``typing.List``).

    Returns an empty list when the annotation is accepted. Otherwise it returns
    a single-element list holding a linter message that describes the problem.
    """
    arg_type = arg.annotation  # type hint

    def err(reason: str) -> list[LinterMessage]:
        msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
        return [self._gen_linter_message(msg, arg.lineno)]

    if not arg_type:
        return err("Missing type annotation")

    # Case 1: optional
    if inner_type := get_optional_type(arg_type):
        # NOTE: optional types can be primitives or any of the allowed container
        # types, so unpack the optional and validate the inner type below
        arg_type = inner_type

    # Case 2: int, float, str, bool
    if is_primitive(arg_type):
        return OK()

    # Case 3: Containers (Dict, List, Tuple)
    if isinstance(arg_type, ast.Subscript):
        container = arg_type.value
        # The subscripted value is usually a bare name (``List``). It may also be
        # a dotted attribute (``typing.List``) or another expression, so do not
        # assume an ast.Name (reading ``.id`` would raise AttributeError).
        if isinstance(container, ast.Name):
            container_type = container.id
        elif isinstance(container, ast.Attribute):
            container_type = container.attr
        else:
            container_type = ast.unparse(container)

        if container_type in ["Dict", "dict"]:
            KV = get_generic_type(arg_type)
            # dict[K, V] has an ast.Tuple slice with exactly two elements; report
            # malformed declarations instead of failing an assertion
            if not isinstance(KV, ast.Tuple) or len(KV.elts) != 2:
                return err(
                    f"Dict must declare exactly one key and one value type, got {ast.unparse(KV)!r}"
                )
            K, V = KV.elts
            if not is_primitive(K):
                return err(f"Non-primitive key type {ast.unparse(K)!r}")
            if not is_primitive(V):
                return err(f"Non-primitive value type {ast.unparse(V)!r}")
            return OK()

        if container_type in ["List", "list"]:
            T = get_generic_type(arg_type)
            if is_primitive(T):
                return OK()
            return err(f"Non-primitive element type {ast.unparse(T)!r}")

        if container_type in ["Tuple", "tuple"]:
            E_N = get_generic_type(arg_type)
            # tuple[T1, T2] has an ast.Tuple slice, but a single-element
            # tuple[T1] has the bare element expression as its slice
            elements = E_N.elts if isinstance(E_N, ast.Tuple) else [E_N]
            for e in elements:
                if not is_primitive(e):
                    return err(f"Non-primitive element type {ast.unparse(e)!r}")
            return OK()

        return err(f"Unsupported container type {container_type!r}")

    return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
|
|
243
|
+
|
|
174
244
|
|
|
245
|
+
class ReturnTypeValidator(ComponentFunctionValidator):
|
|
246
|
+
"""Validates that component functions always return AppDef type"""
|
|
247
|
+
|
|
248
|
+
def __init__(self, supported_return_type: str) -> None:
    """Create a validator for component function return annotations.

    Args:
        supported_return_type: the return annotation that ``validate`` accepts,
            e.g. ``"AppDef"``
    """
    super().__init__()
    # read back by ``validate`` as the supported return annotation
    self._supported_return_type = supported_return_type
|
|
175
251
|
|
|
176
|
-
class TorchxReturnValidator(TorchxFunctionValidator):
|
|
177
252
|
def _get_return_annotation(
|
|
178
253
|
self, app_specs_func_def: ast.FunctionDef
|
|
179
254
|
) -> Optional[str]:
|
|
@@ -197,7 +272,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
|
|
|
197
272
|
* AppDef
|
|
198
273
|
* specs.AppDef
|
|
199
274
|
"""
|
|
200
|
-
supported_return_annotation =
|
|
275
|
+
supported_return_annotation = self._supported_return_type
|
|
201
276
|
return_annotation = self._get_return_annotation(app_specs_func_def)
|
|
202
277
|
linter_errors = []
|
|
203
278
|
if not return_annotation:
|
|
@@ -220,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
|
|
|
220
295
|
return linter_errors
|
|
221
296
|
|
|
222
297
|
|
|
223
|
-
class
|
|
298
|
+
class ComponentFnVisitor(ast.NodeVisitor):
|
|
224
299
|
"""
|
|
225
300
|
Visitor that finds the component_function and runs registered validators on it.
|
|
226
301
|
Current registered validators:
|
|
@@ -238,11 +313,18 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
238
313
|
|
|
239
314
|
"""
|
|
240
315
|
|
|
241
|
-
def __init__(
|
|
242
|
-
self
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
316
|
+
def __init__(
|
|
317
|
+
self,
|
|
318
|
+
component_function_name: str,
|
|
319
|
+
validators: Optional[List[ComponentFunctionValidator]],
|
|
320
|
+
) -> None:
|
|
321
|
+
if validators is None:
|
|
322
|
+
self.validators: List[ComponentFunctionValidator] = [
|
|
323
|
+
ArgTypeValidator(),
|
|
324
|
+
ReturnTypeValidator("AppDef"),
|
|
325
|
+
]
|
|
326
|
+
else:
|
|
327
|
+
self.validators = validators
|
|
246
328
|
self.linter_errors: List[LinterMessage] = []
|
|
247
329
|
self.component_function_name = component_function_name
|
|
248
330
|
self.visited_function = False
|
|
@@ -258,7 +340,11 @@ class TorchFunctionVisitor(ast.NodeVisitor):
|
|
|
258
340
|
self.linter_errors += validator.validate(node)
|
|
259
341
|
|
|
260
342
|
|
|
261
|
-
def validate(
|
|
343
|
+
def validate(
|
|
344
|
+
path: str,
|
|
345
|
+
component_function: str,
|
|
346
|
+
validators: Optional[List[ComponentFunctionValidator]] = None,
|
|
347
|
+
) -> List[LinterMessage]:
|
|
262
348
|
"""
|
|
263
349
|
Validates the function to make sure it complies the component standard.
|
|
264
350
|
|
|
@@ -287,7 +373,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]:
|
|
|
287
373
|
severity="error",
|
|
288
374
|
)
|
|
289
375
|
return [linter_message]
|
|
290
|
-
visitor =
|
|
376
|
+
visitor = ComponentFnVisitor(component_function, validators)
|
|
291
377
|
visitor.visit(module)
|
|
292
378
|
linter_errors = visitor.linter_errors
|
|
293
379
|
if not visitor.visited_function:
|