torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic; the original page links to more details.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/builders.py CHANGED
@@ -4,27 +4,46 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
8
+
7
9
  import argparse
8
10
  import inspect
9
- from typing import Any, Callable, Dict, List, Mapping, Optional, Union
11
+ import os
12
+ from argparse import Namespace
13
+ from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
10
14
 
11
15
  from torchx.specs.api import BindMount, MountType, VolumeMount
12
16
  from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
13
- from torchx.util.types import (
14
- decode_from_string,
15
- decode_optional,
16
- get_argparse_param_type,
17
- is_bool,
18
- is_primitive,
19
- )
17
+ from torchx.util.types import decode, decode_optional, get_argparse_param_type, is_bool
20
18
 
21
19
  from .api import AppDef, DeviceMount
22
20
 
23
21
 
22
+ class ComponentArgs(NamedTuple):
23
+ """Parsed component function arguments"""
24
+
25
+ positional_args: dict[str, Any]
26
+ var_args: list[str]
27
+ kwargs: dict[str, Any]
28
+
29
+
24
30
  def _create_args_parser(
25
- cmpnt_fn: Callable[..., AppDef], cmpnt_defaults: Optional[Dict[str, str]] = None
31
+ cmpnt_fn: Callable[..., AppDef],
32
+ cmpnt_defaults: Optional[Dict[str, str]] = None,
33
+ config: Optional[Dict[str, Any]] = None,
26
34
  ) -> argparse.ArgumentParser:
27
35
  parameters = inspect.signature(cmpnt_fn).parameters
36
+ return _create_args_parser_from_parameters(
37
+ cmpnt_fn, parameters, cmpnt_defaults, config
38
+ )
39
+
40
+
41
+ def _create_args_parser_from_parameters(
42
+ cmpnt_fn: Callable[..., AppDef],
43
+ parameters: Mapping[str, inspect.Parameter],
44
+ cmpnt_defaults: Optional[Dict[str, str]] = None,
45
+ config: Optional[Dict[str, Any]] = None,
46
+ ) -> argparse.ArgumentParser:
28
47
  function_desc, args_desc = get_fn_docstring(cmpnt_fn)
29
48
  script_parser = argparse.ArgumentParser(
30
49
  prog=f"torchx run <run args...> {cmpnt_fn.__name__} ",
@@ -85,15 +104,144 @@ def _create_args_parser(
85
104
  if len(param_name) == 1:
86
105
  arg_names = [f"-{param_name}"] + arg_names
87
106
  if "default" not in args:
88
- args["required"] = True
107
+ if (config and param_name not in config) or not config:
108
+ args["required"] = True
109
+
89
110
  script_parser.add_argument(*arg_names, **args)
90
111
  return script_parser
91
112
 
92
113
 
114
+ def _merge_config_values_with_args(
115
+ parsed_args: argparse.Namespace, config: Dict[str, Any]
116
+ ) -> None:
117
+ for key, val in config.items():
118
+ if key in parsed_args:
119
+ setattr(parsed_args, key, val)
120
+
121
+
122
+ def parse_args(
123
+ cmpnt_fn: Callable[..., AppDef],
124
+ cmpnt_args: List[str],
125
+ cmpnt_defaults: Optional[Dict[str, Any]] = None,
126
+ config: Optional[Dict[str, Any]] = None,
127
+ ) -> Namespace:
128
+ """
129
+ Parse passed arguments, defaults, and config values into a namespace for
130
+ a component function.
131
+
132
+ Args:
133
+ cmpnt_fn: Component function
134
+ cmpnt_args: Function args
135
+ cmpnt_defaults: Additional default values for parameters of ``app_fn``
136
+ (overrides the defaults set on the fn declaration)
137
+ config: Optional dict containing additional configuration for the component from a passed config file
138
+
139
+ Returns:
140
+ A Namespace object with the args, defaults, and config values incorporated.
141
+ """
142
+
143
+ script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults, config)
144
+ parsed_args = script_parser.parse_args(cmpnt_args)
145
+ if config:
146
+ _merge_config_values_with_args(parsed_args, config)
147
+
148
+ return parsed_args
149
+
150
+
151
+ def component_args_from_str(
152
+ cmpnt_fn: Callable[..., AppDef],
153
+ cmpnt_args: list[str],
154
+ cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
155
+ config: Optional[Dict[str, Any]] = None,
156
+ ) -> ComponentArgs:
157
+ """
158
+ Parses and decodes command-line arguments for a component function.
159
+
160
+ This function takes a component function and its arguments, parses them using argparse,
161
+ and decodes the arguments into their expected types based on the function's signature.
162
+ It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.
163
+
164
+ Args:
165
+ cmpnt_fn: The component function whose arguments are to be parsed and decoded.
166
+ cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
167
+ cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
168
+ config: Optional dictionary containing additional configuration values.
169
+
170
+ Returns:
171
+ ComponentArgs representing the input args to a component function containing:
172
+ - positional_args: Dictionary of positional and positional-or-keyword arguments.
173
+ - var_args: List of variable positional arguments (*args).
174
+ - kwargs: Dictionary of keyword-only arguments.
175
+
176
+ Usage:
177
+
178
+ .. doctest::
179
+ from torchx.specs.api import AppDef
180
+ from torchx.specs.builders import component_args_from_str
181
+
182
+ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
183
+ return AppDef(name="example")
184
+
185
+ # Supports space separated arguments
186
+ args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
187
+ parsed_args = component_args_from_str(example_component_fn, args)
188
+
189
+ assert parsed_args.positional_args == {"foo": "fooval"}
190
+ assert parsed_args.var_args == ["arg1", "arg2"]
191
+ assert parsed_args.kwargs == {"bar": "barval"}
192
+
193
+ # Supports '=' separated arguments
194
+ args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
195
+ parsed_args = component_args_from_str(example_component_fn, args)
196
+
197
+ assert parsed_args.positional_args == {"foo": "fooval"}
198
+ assert parsed_args.var_args == ["arg1", "arg2"]
199
+ assert parsed_args.kwargs == {"bar": "barval"}
200
+
201
+
202
+ """
203
+ parsed_args: Namespace = parse_args(
204
+ cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
205
+ )
206
+
207
+ positional_args = {}
208
+ var_args = []
209
+ kwargs = {}
210
+
211
+ parameters = inspect.signature(cmpnt_fn).parameters
212
+ for param_name, parameter in parameters.items():
213
+ arg_value = getattr(parsed_args, param_name)
214
+ parameter_type = parameter.annotation
215
+ parameter_type = decode_optional(parameter_type)
216
+ if (
217
+ parameter_type != arg_value.__class__
218
+ and parameter.kind != inspect.Parameter.VAR_POSITIONAL
219
+ ):
220
+ arg_value = decode(arg_value, parameter_type)
221
+ if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
222
+ var_args = arg_value
223
+ elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
224
+ kwargs[param_name] = arg_value
225
+ elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
226
+ raise TypeError(
227
+ f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
228
+ f"type to a dict or explicitly declare the params"
229
+ )
230
+ else:
231
+ # POSITIONAL or POSITIONAL_OR_KEYWORD
232
+ positional_args[param_name] = arg_value
233
+
234
+ if len(var_args) > 0 and var_args[0] == "--":
235
+ var_args = var_args[1:]
236
+
237
+ return ComponentArgs(positional_args, var_args, kwargs)
238
+
239
+
93
240
  def materialize_appdef(
94
241
  cmpnt_fn: Callable[..., AppDef],
95
242
  cmpnt_args: List[str],
96
- cmpnt_defaults: Optional[Dict[str, str]] = None,
243
+ cmpnt_defaults: Optional[Dict[str, Any]] = None,
244
+ config: Optional[Dict[str, Any]] = None,
97
245
  ) -> AppDef:
98
246
  """
99
247
  Creates an application by running user defined ``app_fn``.
@@ -118,38 +266,25 @@ def materialize_appdef(
118
266
  cmpnt_args: Function args
119
267
  cmpnt_defaults: Additional default values for parameters of ``app_fn``
120
268
  (overrides the defaults set on the fn declaration)
269
+ config: Optional dict containing additional configuration for the component from a passed config file
121
270
  Returns:
122
271
  An application spec
123
272
  """
124
273
 
125
- script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults)
126
- parsed_args = script_parser.parse_args(cmpnt_args)
127
-
128
- function_args = []
129
- var_arg = []
130
- kwargs = {}
274
+ component_args: ComponentArgs = component_args_from_str(
275
+ cmpnt_fn, cmpnt_args, cmpnt_defaults, config
276
+ )
277
+ positional_arg_values = list(component_args.positional_args.values())
278
+ appdef = cmpnt_fn(
279
+ *positional_arg_values, *component_args.var_args, **component_args.kwargs
280
+ )
131
281
 
132
- parameters = inspect.signature(cmpnt_fn).parameters
133
- for param_name, parameter in parameters.items():
134
- arg_value = getattr(parsed_args, param_name)
135
- parameter_type = parameter.annotation
136
- parameter_type = decode_optional(parameter_type)
137
- if is_bool(parameter_type):
138
- arg_value = arg_value and arg_value.lower() == "true"
139
- elif not is_primitive(parameter_type):
140
- arg_value = decode_from_string(arg_value, parameter_type)
141
- if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
142
- var_arg = arg_value
143
- elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
144
- kwargs[param_name] = arg_value
145
- elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
146
- raise TypeError("**kwargs are not supported for component definitions")
147
- else:
148
- function_args.append(arg_value)
149
- if len(var_arg) > 0 and var_arg[0] == "--":
150
- var_arg = var_arg[1:]
282
+ if not isinstance(appdef, AppDef):
283
+ raise TypeError(
284
+ f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
285
+ )
151
286
 
152
- return cmpnt_fn(*function_args, *var_arg, **kwargs)
287
+ return appdef
153
288
 
154
289
 
155
290
  def make_app_handle(scheduler_backend: str, session_name: str, app_id: str) -> str:
@@ -205,9 +340,12 @@ def parse_mounts(opts: List[str]) -> List[Union[BindMount, VolumeMount, DeviceMo
205
340
  for opts in mount_opts:
206
341
  typ = opts.get("type")
207
342
  if typ == MountType.BIND:
343
+ src_path = opts["src"]
344
+ if src_path.startswith("~"):
345
+ src_path = os.path.expanduser(src_path)
208
346
  mounts.append(
209
347
  BindMount(
210
- src_path=opts["src"],
348
+ src_path=src_path,
211
349
  dst_path=opts["dst"],
212
350
  read_only="readonly" in opts,
213
351
  )
torchx/specs/file_linter.py CHANGED
@@ -5,12 +5,15 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  import abc
9
11
  import argparse
10
12
  import ast
11
13
  import inspect
14
+ import sys
12
15
  from dataclasses import dataclass
13
- from typing import Callable, cast, Dict, List, Optional, Tuple
16
+ from typing import Callable, Dict, List, Optional, Tuple
14
17
 
15
18
  from docstring_parser import parse
16
19
  from torchx.util.io import read_conf_file
@@ -29,7 +32,11 @@ def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str,
29
32
  return args_decs
30
33
 
31
34
 
32
- class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
35
+ class TorchXArgumentHelpFormatter(
36
+ argparse.RawDescriptionHelpFormatter,
37
+ argparse.ArgumentDefaultsHelpFormatter,
38
+ argparse.MetavarTypeHelpFormatter,
39
+ ):
33
40
  """Help message formatter which adds default values and required to argument help.
34
41
 
35
42
  If the argument is required, the class appends `(required)` at the end of the help message.
@@ -68,7 +75,7 @@ def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
68
75
  if the description
69
76
  """
70
77
  default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
71
- to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
78
+ to your component (see: https://meta-pytorch.org/torchx/latest/component_best_practices.html)"""
72
79
  args_description = _get_default_arguments_descriptions(fn)
73
80
  func_description = inspect.getdoc(fn)
74
81
  if not func_description:
@@ -79,7 +86,7 @@ to your component (see: https://pytorch.org/torchx/latest/component_best_practic
79
86
  args_description[param.arg_name] = param.description
80
87
  short_func_description = docstring.short_description or default_fn_desc
81
88
  if docstring.long_description:
82
- short_func_description += " ..."
89
+ short_func_description += "\n" + docstring.long_description
83
90
  return (short_func_description or default_fn_desc, args_description)
84
91
 
85
92
 
@@ -92,7 +99,7 @@ class LinterMessage:
92
99
  severity: str = "error"
93
100
 
94
101
 
95
- class TorchxFunctionValidator(abc.ABC):
102
+ class ComponentFunctionValidator(abc.ABC):
96
103
  @abc.abstractmethod
97
104
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
98
105
  """
@@ -110,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
110
117
  )
111
118
 
112
119
 
113
- class TorchxFunctionArgsValidator(TorchxFunctionValidator):
120
+ def OK() -> list[LinterMessage]:
121
+ return [] # empty linter error means validation passes
122
+
123
+
124
+ def is_primitive(arg: ast.expr) -> bool:
125
+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126
+ return isinstance(arg, ast.Name)
127
+
128
+
129
+ def get_generic_type(arg: ast.expr) -> ast.expr:
130
+ # returns the slice expr of a subscripted type
131
+ # `arg` must be an instance of ast.Subscript (caller checks)
132
+ # in this validator's context, this is the generic type of a container type
133
+ # e.g. for Optional[str] returns the expr for str
134
+
135
+ assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
136
+
137
+ if isinstance(arg.slice, ast.Index): # python>=3.10
138
+ return arg.slice.value
139
+ else: # python-3.9
140
+ return arg.slice
141
+
142
+
143
+ def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
144
+ """
145
+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146
+ is not an ``Optional``. Handles both:
147
+ 1. ``typing.Optional[T]`` (python<3.10)
148
+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149
+ """
150
+ # case 1: 'a: Optional[T]'
151
+ if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
152
+ return get_generic_type(arg)
153
+
154
+ # case 2: 'a: T | None' or 'a: None | T'
155
+ if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
156
+ if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
157
+ if isinstance(arg.right, ast.Constant) and arg.right.value is None:
158
+ return arg.left
159
+ if isinstance(arg.left, ast.Constant) and arg.left.value is None:
160
+ return arg.right
161
+
162
+ # case 3: is not optional
163
+ return None
164
+
165
+
166
+ class ArgTypeValidator(ComponentFunctionValidator):
167
+ """Validates component function's argument types."""
168
+
114
169
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
115
170
  linter_errors = []
116
171
  for arg_def in app_specs_func_def.args.args:
@@ -127,53 +182,73 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
127
182
  return linter_errors
128
183
 
129
184
  def _validate_arg_def(
130
- self, function_name: str, arg_def: ast.arg
185
+ self, function_name: str, arg: ast.arg
131
186
  ) -> List[LinterMessage]:
132
- if not arg_def.annotation:
133
- return [
134
- self._gen_linter_message(
135
- f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
136
- )
137
- ]
138
- if isinstance(arg_def.annotation, ast.Name):
187
+ arg_type = arg.annotation # type hint
188
+
189
+ def ok() -> list[LinterMessage]:
190
+ # return value when validation passes (e.g. no linter errors)
139
191
  return []
140
- complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
141
- if complex_type_def.value.id == "Optional":
142
- # ast module in python3.9 does not have ast.Index wrapper
143
- if isinstance(complex_type_def.slice, ast.Index):
144
- complex_type_def = complex_type_def.slice.value
145
- else:
146
- complex_type_def = complex_type_def.slice
147
- # Check if type is Optional[primitive_type]
148
- if isinstance(complex_type_def, ast.Name):
149
- return []
150
- # Check if type is Union[Dict,List]
151
- type_name = complex_type_def.value.id
152
- if type_name != "Dict" and type_name != "List":
153
- desc = (
154
- f"`{function_name}` allows only Dict, List as complex types."
155
- f"Argument `{arg_def.arg}` has: {type_name}"
156
- )
157
- return [self._gen_linter_message(desc, arg_def.lineno)]
158
- linter_errors = []
159
- # ast module in python3.9 does not have objects wrapped in ast.Index
160
- if isinstance(complex_type_def.slice, ast.Index):
161
- sub_type = complex_type_def.slice.value
192
+
193
+ def err(reason: str) -> list[LinterMessage]:
194
+ msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
195
+ return [self._gen_linter_message(msg, arg.lineno)]
196
+
197
+ if not arg_type:
198
+ return err("Missing type annotation")
199
+
200
+ # Case 1: optional
201
+ if T := get_optional_type(arg_type):
202
+ # NOTE: optional types can be primitives or any of the allowed container types
203
+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204
+ arg_type = T
205
+
206
+ # Case 2: int, float, str, bool
207
+ if is_primitive(arg_type):
208
+ return ok()
209
+ # Case 3: Containers (Dict, List, Tuple)
210
+ elif isinstance(arg_type, ast.Subscript):
211
+ container_type = arg_type.value.id
212
+
213
+ if container_type in ["Dict", "dict"]:
214
+ KV = get_generic_type(arg_type)
215
+
216
+ assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
217
+
218
+ K, V = KV.elts
219
+ if not is_primitive(K):
220
+ return err(f"Non-primitive key type {ast.unparse(K)!r}")
221
+ if not is_primitive(V):
222
+ return err(f"Non-primitive value type {ast.unparse(V)!r}")
223
+ return ok()
224
+ elif container_type in ["List", "list"]:
225
+ T = get_generic_type(arg_type)
226
+ if is_primitive(T):
227
+ return ok()
228
+ else:
229
+ return err(f"Non-primitive element type {ast.unparse(T)!r}")
230
+ elif container_type in ["Tuple", "tuple"]:
231
+ E_N = get_generic_type(arg_type)
232
+ assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
233
+
234
+ for e in E_N.elts:
235
+ if not is_primitive(e):
236
+ return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
237
+
238
+ return ok()
239
+
240
+ return err(f"Unsupported container type {container_type!r}")
162
241
  else:
163
- sub_type = complex_type_def.slice
164
- if type_name == "Dict":
165
- sub_type_tuple = cast(ast.Tuple, sub_type)
166
- for el in sub_type_tuple.elts:
167
- if not isinstance(el, ast.Name):
168
- desc = "Dict can only have primitive types"
169
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
170
- elif not isinstance(sub_type, ast.Name):
171
- desc = "List can only have primitive types"
172
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
173
- return linter_errors
242
+ return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
243
+
174
244
 
245
+ class ReturnTypeValidator(ComponentFunctionValidator):
246
+ """Validates that component functions always return AppDef type"""
247
+
248
+ def __init__(self, supported_return_type: str) -> None:
249
+ super().__init__()
250
+ self._supported_return_type = supported_return_type
175
251
 
176
- class TorchxReturnValidator(TorchxFunctionValidator):
177
252
  def _get_return_annotation(
178
253
  self, app_specs_func_def: ast.FunctionDef
179
254
  ) -> Optional[str]:
@@ -197,7 +272,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
197
272
  * AppDef
198
273
  * specs.AppDef
199
274
  """
200
- supported_return_annotation = "AppDef"
275
+ supported_return_annotation = self._supported_return_type
201
276
  return_annotation = self._get_return_annotation(app_specs_func_def)
202
277
  linter_errors = []
203
278
  if not return_annotation:
@@ -220,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
220
295
  return linter_errors
221
296
 
222
297
 
223
- class TorchFunctionVisitor(ast.NodeVisitor):
298
+ class ComponentFnVisitor(ast.NodeVisitor):
224
299
  """
225
300
  Visitor that finds the component_function and runs registered validators on it.
226
301
  Current registered validators:
@@ -238,11 +313,18 @@ class TorchFunctionVisitor(ast.NodeVisitor):
238
313
 
239
314
  """
240
315
 
241
- def __init__(self, component_function_name: str) -> None:
242
- self.validators = [
243
- TorchxFunctionArgsValidator(),
244
- TorchxReturnValidator(),
245
- ]
316
+ def __init__(
317
+ self,
318
+ component_function_name: str,
319
+ validators: Optional[List[ComponentFunctionValidator]],
320
+ ) -> None:
321
+ if validators is None:
322
+ self.validators: List[ComponentFunctionValidator] = [
323
+ ArgTypeValidator(),
324
+ ReturnTypeValidator("AppDef"),
325
+ ]
326
+ else:
327
+ self.validators = validators
246
328
  self.linter_errors: List[LinterMessage] = []
247
329
  self.component_function_name = component_function_name
248
330
  self.visited_function = False
@@ -258,7 +340,11 @@ class TorchFunctionVisitor(ast.NodeVisitor):
258
340
  self.linter_errors += validator.validate(node)
259
341
 
260
342
 
261
- def validate(path: str, component_function: str) -> List[LinterMessage]:
343
+ def validate(
344
+ path: str,
345
+ component_function: str,
346
+ validators: Optional[List[ComponentFunctionValidator]] = None,
347
+ ) -> List[LinterMessage]:
262
348
  """
263
349
  Validates the function to make sure it complies the component standard.
264
350
 
@@ -287,7 +373,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]:
287
373
  severity="error",
288
374
  )
289
375
  return [linter_message]
290
- visitor = TorchFunctionVisitor(component_function)
376
+ visitor = ComponentFnVisitor(component_function, validators)
291
377
  visitor.visit(module)
292
378
  linter_errors = visitor.linter_errors
293
379
  if not visitor.visited_function: