torchx-nightly 2025.9.9__py3-none-any.whl → 2025.9.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. See the registry's advisory page for this release for more details.

@@ -255,7 +255,7 @@ def _role_to_node_properties(
255
255
  container["jobRoleArn"] = job_role_arn
256
256
  if execution_role_arn:
257
257
  container["executionRoleArn"] = execution_role_arn
258
- if role.num_replicas > 1:
258
+ if role.num_replicas > 0:
259
259
  instance_type = instance_type_from_resource(role.resource)
260
260
  if instance_type is not None:
261
261
  container["instanceType"] = instance_type
@@ -11,8 +11,9 @@ import abc
11
11
  import argparse
12
12
  import ast
13
13
  import inspect
14
+ import sys
14
15
  from dataclasses import dataclass
15
- from typing import Callable, cast, Dict, List, Optional, Tuple
16
+ from typing import Callable, Dict, List, Optional, Tuple
16
17
 
17
18
  from docstring_parser import parse
18
19
  from torchx.util.io import read_conf_file
@@ -98,7 +99,7 @@ class LinterMessage:
98
99
  severity: str = "error"
99
100
 
100
101
 
101
- class TorchxFunctionValidator(abc.ABC):
102
+ class ComponentFunctionValidator(abc.ABC):
102
103
  @abc.abstractmethod
103
104
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
104
105
  """
@@ -116,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
116
117
  )
117
118
 
118
119
 
119
- class TorchxFunctionArgsValidator(TorchxFunctionValidator):
120
+ def OK() -> list[LinterMessage]:
121
+ return [] # empty linter error means validation passes
122
+
123
+
124
+ def is_primitive(arg: ast.expr) -> bool:
125
+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126
+ return isinstance(arg, ast.Name)
127
+
128
+
129
+ def get_generic_type(arg: ast.expr) -> ast.expr:
130
+ # returns the slice expr of a subscripted type
131
+ # `arg` must be an instance of ast.Subscript (caller checks)
132
+ # in this validator's context, this is the generic type of a container type
133
+ # e.g. for Optional[str] returns the expr for str
134
+
135
+ assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
136
+
137
+ if isinstance(arg.slice, ast.Index): # python>=3.10
138
+ return arg.slice.value
139
+ else: # python-3.9
140
+ return arg.slice
141
+
142
+
143
+ def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
144
+ """
145
+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146
+ is not an ``Optional``. Handles both:
147
+ 1. ``typing.Optional[T]`` (python<3.10)
148
+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149
+ """
150
+ # case 1: 'a: Optional[T]'
151
+ if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
152
+ return get_generic_type(arg)
153
+
154
+ # case 2: 'a: T | None' or 'a: None | T'
155
+ if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
156
+ if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
157
+ if isinstance(arg.right, ast.Constant) and arg.right.value is None:
158
+ return arg.left
159
+ if isinstance(arg.left, ast.Constant) and arg.left.value is None:
160
+ return arg.right
161
+
162
+ # case 3: is not optional
163
+ return None
164
+
165
+
166
+ class ArgTypeValidator(ComponentFunctionValidator):
167
+ """Validates component function's argument types."""
168
+
120
169
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
121
170
  linter_errors = []
122
171
  for arg_def in app_specs_func_def.args.args:
@@ -133,53 +182,68 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
133
182
  return linter_errors
134
183
 
135
184
  def _validate_arg_def(
136
- self, function_name: str, arg_def: ast.arg
185
+ self, function_name: str, arg: ast.arg
137
186
  ) -> List[LinterMessage]:
138
- if not arg_def.annotation:
139
- return [
140
- self._gen_linter_message(
141
- f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
142
- )
143
- ]
144
- if isinstance(arg_def.annotation, ast.Name):
187
+ arg_type = arg.annotation # type hint
188
+
189
+ def ok() -> list[LinterMessage]:
190
+ # return value when validation passes (e.g. no linter errors)
145
191
  return []
146
- complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
147
- if complex_type_def.value.id == "Optional":
148
- # ast module in python3.9 does not have ast.Index wrapper
149
- if isinstance(complex_type_def.slice, ast.Index):
150
- complex_type_def = complex_type_def.slice.value
151
- else:
152
- complex_type_def = complex_type_def.slice
153
- # Check if type is Optional[primitive_type]
154
- if isinstance(complex_type_def, ast.Name):
155
- return []
156
- # Check if type is Union[Dict,List]
157
- type_name = complex_type_def.value.id
158
- if type_name != "Dict" and type_name != "List":
159
- desc = (
160
- f"`{function_name}` allows only Dict, List as complex types."
161
- f"Argument `{arg_def.arg}` has: {type_name}"
162
- )
163
- return [self._gen_linter_message(desc, arg_def.lineno)]
164
- linter_errors = []
165
- # ast module in python3.9 does not have objects wrapped in ast.Index
166
- if isinstance(complex_type_def.slice, ast.Index):
167
- sub_type = complex_type_def.slice.value
192
+
193
+ def err(reason: str) -> list[LinterMessage]:
194
+ msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
195
+ return [self._gen_linter_message(msg, arg.lineno)]
196
+
197
+ if not arg_type:
198
+ return err("Missing type annotation")
199
+
200
+ # Case 1: optional
201
+ if T := get_optional_type(arg_type):
202
+ # NOTE: optional types can be primitives or any of the allowed container types
203
+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204
+ arg_type = T
205
+
206
+ # Case 2: int, float, str, bool
207
+ if is_primitive(arg_type):
208
+ return ok()
209
+ # Case 3: Containers (Dict, List, Tuple)
210
+ elif isinstance(arg_type, ast.Subscript):
211
+ container_type = arg_type.value.id
212
+
213
+ if container_type in ["Dict", "dict"]:
214
+ KV = get_generic_type(arg_type)
215
+
216
+ assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
217
+
218
+ K, V = KV.elts
219
+ if not is_primitive(K):
220
+ return err(f"Non-primitive key type {ast.unparse(K)!r}")
221
+ if not is_primitive(V):
222
+ return err(f"Non-primitive value type {ast.unparse(V)!r}")
223
+ return ok()
224
+ elif container_type in ["List", "list"]:
225
+ T = get_generic_type(arg_type)
226
+ if is_primitive(T):
227
+ return ok()
228
+ else:
229
+ return err(f"Non-primitive element type {ast.unparse(T)!r}")
230
+ elif container_type in ["Tuple", "tuple"]:
231
+ E_N = get_generic_type(arg_type)
232
+ assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
233
+
234
+ for e in E_N.elts:
235
+ if not is_primitive(e):
236
+ return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
237
+
238
+ return ok()
239
+
240
+ return err(f"Unsupported container type {container_type!r}")
168
241
  else:
169
- sub_type = complex_type_def.slice
170
- if type_name == "Dict":
171
- sub_type_tuple = cast(ast.Tuple, sub_type)
172
- for el in sub_type_tuple.elts:
173
- if not isinstance(el, ast.Name):
174
- desc = "Dict can only have primitive types"
175
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
176
- elif not isinstance(sub_type, ast.Name):
177
- desc = "List can only have primitive types"
178
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
179
- return linter_errors
242
+ return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
180
243
 
181
244
 
182
- class TorchxReturnValidator(TorchxFunctionValidator):
245
+ class ReturnTypeValidator(ComponentFunctionValidator):
246
+ """Validates that component functions always return AppDef type"""
183
247
 
184
248
  def __init__(self, supported_return_type: str) -> None:
185
249
  super().__init__()
@@ -231,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
231
295
  return linter_errors
232
296
 
233
297
 
234
- class TorchFunctionVisitor(ast.NodeVisitor):
298
+ class ComponentFnVisitor(ast.NodeVisitor):
235
299
  """
236
300
  Visitor that finds the component_function and runs registered validators on it.
237
301
  Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252
316
  def __init__(
253
317
  self,
254
318
  component_function_name: str,
255
- validators: Optional[List[TorchxFunctionValidator]],
319
+ validators: Optional[List[ComponentFunctionValidator]],
256
320
  ) -> None:
257
321
  if validators is None:
258
- self.validators: List[TorchxFunctionValidator] = [
259
- TorchxFunctionArgsValidator(),
260
- TorchxReturnValidator("AppDef"),
322
+ self.validators: List[ComponentFunctionValidator] = [
323
+ ArgTypeValidator(),
324
+ ReturnTypeValidator("AppDef"),
261
325
  ]
262
326
  else:
263
327
  self.validators = validators
@@ -279,7 +343,7 @@ class TorchFunctionVisitor(ast.NodeVisitor):
279
343
  def validate(
280
344
  path: str,
281
345
  component_function: str,
282
- validators: Optional[List[TorchxFunctionValidator]],
346
+ validators: Optional[List[ComponentFunctionValidator]] = None,
283
347
  ) -> List[LinterMessage]:
284
348
  """
285
349
  Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309
373
  severity="error",
310
374
  )
311
375
  return [linter_message]
312
- visitor = TorchFunctionVisitor(component_function, validators)
376
+ visitor = ComponentFnVisitor(component_function, validators)
313
377
  visitor.visit(module)
314
378
  linter_errors = visitor.linter_errors
315
379
  if not visitor.visited_function:
torchx/specs/finder.py CHANGED
@@ -19,7 +19,11 @@ from pathlib import Path
19
19
  from types import ModuleType
20
20
  from typing import Any, Callable, Dict, Generator, List, Optional, Union
21
21
 
22
- from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
22
+ from torchx.specs.file_linter import (
23
+ ComponentFunctionValidator,
24
+ get_fn_docstring,
25
+ validate,
26
+ )
23
27
  from torchx.util import entrypoints
24
28
  from torchx.util.io import read_conf_file
25
29
  from torchx.util.types import none_throws
@@ -64,7 +68,7 @@ class _Component:
64
68
  class ComponentsFinder(abc.ABC):
65
69
  @abc.abstractmethod
66
70
  def find(
67
- self, validators: Optional[List[TorchxFunctionValidator]]
71
+ self, validators: Optional[List[ComponentFunctionValidator]]
68
72
  ) -> List[_Component]:
69
73
  """
70
74
  Retrieves a set of components. A component is defined as a python
@@ -210,7 +214,7 @@ class ModuleComponentsFinder(ComponentsFinder):
210
214
  yield self._try_import(module_info.name)
211
215
 
212
216
  def find(
213
- self, validators: Optional[List[TorchxFunctionValidator]]
217
+ self, validators: Optional[List[ComponentFunctionValidator]]
214
218
  ) -> List[_Component]:
215
219
  components = []
216
220
  for m in self._iter_modules_recursive(self.base_module):
@@ -230,7 +234,7 @@ class ModuleComponentsFinder(ComponentsFinder):
230
234
  return module
231
235
 
232
236
  def _get_components_from_module(
233
- self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
237
+ self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
234
238
  ) -> List[_Component]:
235
239
  functions = getmembers(module, isfunction)
236
240
  component_defs = []
@@ -269,7 +273,7 @@ class CustomComponentsFinder(ComponentsFinder):
269
273
  self,
270
274
  path: str,
271
275
  function_name: str,
272
- validators: Optional[List[TorchxFunctionValidator]],
276
+ validators: Optional[List[ComponentFunctionValidator]],
273
277
  ) -> List[str]:
274
278
  linter_errors = validate(path, function_name, validators)
275
279
  return [linter_error.description for linter_error in linter_errors]
@@ -283,13 +287,15 @@ class CustomComponentsFinder(ComponentsFinder):
283
287
  my_component defined in some_file.py, imported in other_file.py
284
288
  and the component is invoked as other_file.py:my_component
285
289
  """
286
- path_to_function_decl = inspect.getabsfile(function)
290
+ # Unwrap decorated functions to get the original function
291
+ unwrapped_function = inspect.unwrap(function)
292
+ path_to_function_decl = inspect.getabsfile(unwrapped_function)
287
293
  if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
288
294
  return self._filepath
289
295
  return path_to_function_decl
290
296
 
291
297
  def find(
292
- self, validators: Optional[List[TorchxFunctionValidator]]
298
+ self, validators: Optional[List[ComponentFunctionValidator]]
293
299
  ) -> List[_Component]:
294
300
 
295
301
  file_source = read_conf_file(self._filepath)
@@ -321,7 +327,7 @@ class CustomComponentsFinder(ComponentsFinder):
321
327
 
322
328
 
323
329
  def _load_custom_components(
324
- validators: Optional[List[TorchxFunctionValidator]],
330
+ validators: Optional[List[ComponentFunctionValidator]],
325
331
  ) -> List[_Component]:
326
332
  component_modules = {
327
333
  name: load_fn()
@@ -346,7 +352,7 @@ def _load_custom_components(
346
352
 
347
353
 
348
354
  def _load_components(
349
- validators: Optional[List[TorchxFunctionValidator]],
355
+ validators: Optional[List[ComponentFunctionValidator]],
350
356
  ) -> Dict[str, _Component]:
351
357
  """
352
358
  Loads either the custom component defs from the entrypoint ``[torchx.components]``
@@ -368,7 +374,7 @@ _components: Optional[Dict[str, _Component]] = None
368
374
 
369
375
 
370
376
  def _find_components(
371
- validators: Optional[List[TorchxFunctionValidator]],
377
+ validators: Optional[List[ComponentFunctionValidator]],
372
378
  ) -> Dict[str, _Component]:
373
379
  global _components
374
380
  if not _components:
@@ -381,7 +387,7 @@ def _is_custom_component(component_name: str) -> bool:
381
387
 
382
388
 
383
389
  def _find_custom_components(
384
- name: str, validators: Optional[List[TorchxFunctionValidator]]
390
+ name: str, validators: Optional[List[ComponentFunctionValidator]]
385
391
  ) -> Dict[str, _Component]:
386
392
  if ":" not in name:
387
393
  raise ValueError(
@@ -393,7 +399,7 @@ def _find_custom_components(
393
399
 
394
400
 
395
401
  def get_components(
396
- validators: Optional[List[TorchxFunctionValidator]] = None,
402
+ validators: Optional[List[ComponentFunctionValidator]] = None,
397
403
  ) -> Dict[str, _Component]:
398
404
  """
399
405
  Returns all custom components registered via ``[torchx.components]`` entrypoints
@@ -448,7 +454,7 @@ def get_components(
448
454
 
449
455
 
450
456
  def get_component(
451
- name: str, validators: Optional[List[TorchxFunctionValidator]] = None
457
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
452
458
  ) -> _Component:
453
459
  """
454
460
  Retrieves components by the provided name.
@@ -477,7 +483,7 @@ def get_component(
477
483
 
478
484
 
479
485
  def get_builtin_source(
480
- name: str, validators: Optional[List[TorchxFunctionValidator]] = None
486
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
481
487
  ) -> str:
482
488
  """
483
489
  Returns a string of the the builtin component's function source code
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: torchx-nightly
3
- Version: 2025.9.9
3
+ Version: 2025.9.11
4
4
  Summary: TorchX SDK and Components
5
5
  Home-page: https://github.com/pytorch/torchx
6
6
  Author: TorchX Devs
@@ -66,7 +66,7 @@ torchx/runtime/tracking/__init__.py,sha256=dYnAPnrXYREfPXkpHhdOFkcYIODWEbA13PdD-
66
66
  torchx/runtime/tracking/api.py,sha256=SmUQyUKZqG3KlAhT7CJOGqRz1O274E4m63wQeOVq3CU,5472
67
67
  torchx/schedulers/__init__.py,sha256=igIBdxGhkuzH7oYVFXIA9xwjkSn3QzWZ_9dhfdl_M0I,2299
68
68
  torchx/schedulers/api.py,sha256=FbyG7mNnbXXj2_W1dQ_MnNO6fQe2FAd1rEz1R-BNg-c,14598
69
- torchx/schedulers/aws_batch_scheduler.py,sha256=06CpJw6Y71-g9CH6A6lul5-WLO5qxINf3UKLyyNmzS0,28105
69
+ torchx/schedulers/aws_batch_scheduler.py,sha256=hFxYzSZEK2SVS5sEyQC5YvNI0JJUJUQsWORlYpj_h3M,28105
70
70
  torchx/schedulers/aws_sagemaker_scheduler.py,sha256=flN8GumKE2Dz4X_foAt6Jnvt-ZVojWs6pcyrHwB0hz0,20921
71
71
  torchx/schedulers/devices.py,sha256=RjVcu22ZRl_9OKtOtmA1A3vNXgu2qD6A9ST0L0Hsg4I,1734
72
72
  torchx/schedulers/docker_scheduler.py,sha256=xuK00-dB6o8TV1YaZox7O5P09LHB2KeQ6t4eiNtqMYQ,16781
@@ -85,8 +85,8 @@ torchx/schedulers/ray/ray_driver.py,sha256=RdaCLfth16ky-5PDVOWRe_RuheWJu9xufWux2
85
85
  torchx/specs/__init__.py,sha256=Gw_2actqR_oWFtxEkGXCxGk_yrWK5JDZzwysyyqmXao,6438
86
86
  torchx/specs/api.py,sha256=wkhHOxeWH_tFO3npKqPhNg4VX2NH5gPIFEylkPBo3AU,41315
87
87
  torchx/specs/builders.py,sha256=aozVl4q3h0mY5DDJCY1M1CyLC9SW66KJy8JIih8bZJo,13810
88
- torchx/specs/file_linter.py,sha256=QCwob5STTBuy8RsxaevTI-Dk6R8siDJn81LyaOwazes,12333
89
- torchx/specs/finder.py,sha256=lMCS1hUR-x7j4sAjfJZnKFeS1RlkFTeJtit4EL_ErIo,18182
88
+ torchx/specs/file_linter.py,sha256=6_aoeuS5d9UwXseKKfPgWNTwxj-f7G1i3uO9mQepti4,14402
89
+ torchx/specs/finder.py,sha256=8peKWCKVutCRP0WmffRqsZNqjIh6xfxMKbVA4aEPM_M,18368
90
90
  torchx/specs/named_resources_aws.py,sha256=ISjHtifRJqB8u7PeAMiyLyO_S0WCaZiK-CFF3qe6JDU,11415
91
91
  torchx/specs/named_resources_generic.py,sha256=Sg4tAdqiiWDrDz2Lj_pnfsjzGIXKTou73wPseh6j55w,2646
92
92
  torchx/specs/test/components/__init__.py,sha256=J8qjUOysmcMAek2KFN13mViOXZxTYc5vCrF02t3VuFU,223
@@ -115,9 +115,9 @@ torchx/workspace/__init__.py,sha256=FqN8AN4VhR1C_SBY10MggQvNZmyanbbuPuE-JCjkyUY,
115
115
  torchx/workspace/api.py,sha256=PtDkGTC5lX03pRoYpuMz2KCmM1ZOycRP1UknqvNb97Y,6341
116
116
  torchx/workspace/dir_workspace.py,sha256=npNW_IjUZm_yS5r-8hrRkH46ndDd9a_eApT64m1S1T4,2268
117
117
  torchx/workspace/docker_workspace.py,sha256=PFu2KQNVC-0p2aKJ-W_BKA9ZOmXdCY2ABEkCExp3udQ,10269
118
- torchx_nightly-2025.9.9.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
119
- torchx_nightly-2025.9.9.dist-info/METADATA,sha256=iUePMYe566teMWbIPYyQDMSPzvh6l2rWqOfThyt9GWw,6103
120
- torchx_nightly-2025.9.9.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
121
- torchx_nightly-2025.9.9.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
122
- torchx_nightly-2025.9.9.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
123
- torchx_nightly-2025.9.9.dist-info/RECORD,,
118
+ torchx_nightly-2025.9.11.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
119
+ torchx_nightly-2025.9.11.dist-info/METADATA,sha256=38npZBvhuRMn5mHdg_tuQNWP3Mp0xEWaDdpRbV2YE58,6104
120
+ torchx_nightly-2025.9.11.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
121
+ torchx_nightly-2025.9.11.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
122
+ torchx_nightly-2025.9.11.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
123
+ torchx_nightly-2025.9.11.dist-info/RECORD,,