torchx-nightly 2025.3.19__py3-none-any.whl → 2025.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic; see the registry's advisory page for this release for more details.

torchx/specs/builders.py CHANGED
@@ -101,7 +101,7 @@ def _merge_config_values_with_args(
101
101
 
102
102
 
103
103
  def parse_args(
104
- cmpnt_fn: Callable[..., AppDef],
104
+ cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
105
105
  cmpnt_args: List[str],
106
106
  cmpnt_defaults: Optional[Dict[str, Any]] = None,
107
107
  config: Optional[Dict[str, Any]] = None,
@@ -130,7 +130,7 @@ def parse_args(
130
130
 
131
131
 
132
132
  def materialize_appdef(
133
- cmpnt_fn: Callable[..., AppDef],
133
+ cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
134
134
  cmpnt_args: List[str],
135
135
  cmpnt_defaults: Optional[Dict[str, Any]] = None,
136
136
  config: Optional[Dict[str, Any]] = None,
@@ -187,6 +187,10 @@ def materialize_appdef(
187
187
  var_arg = var_arg[1:]
188
188
 
189
189
  appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
190
+ if not isinstance(appdef, AppDef):
191
+ raise TypeError(
192
+ f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
193
+ )
190
194
 
191
195
  return appdef
192
196
 
@@ -244,11 +244,18 @@ class TorchFunctionVisitor(ast.NodeVisitor):
244
244
 
245
245
  """
246
246
 
247
- def __init__(self, component_function_name: str) -> None:
248
- self.validators = [
249
- TorchxFunctionArgsValidator(),
250
- TorchxReturnValidator(),
251
- ]
247
+ def __init__(
248
+ self,
249
+ component_function_name: str,
250
+ validators: Optional[List[TorchxFunctionValidator]],
251
+ ) -> None:
252
+ if validators is None:
253
+ self.validators: List[TorchxFunctionValidator] = [
254
+ TorchxFunctionArgsValidator(),
255
+ TorchxReturnValidator(),
256
+ ]
257
+ else:
258
+ self.validators = validators
252
259
  self.linter_errors: List[LinterMessage] = []
253
260
  self.component_function_name = component_function_name
254
261
  self.visited_function = False
@@ -264,7 +271,11 @@ class TorchFunctionVisitor(ast.NodeVisitor):
264
271
  self.linter_errors += validator.validate(node)
265
272
 
266
273
 
267
- def validate(path: str, component_function: str) -> List[LinterMessage]:
274
+ def validate(
275
+ path: str,
276
+ component_function: str,
277
+ validators: Optional[List[TorchxFunctionValidator]],
278
+ ) -> List[LinterMessage]:
268
279
  """
269
280
  Validates the function to make sure it complies the component standard.
270
281
 
@@ -293,7 +304,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]:
293
304
  severity="error",
294
305
  )
295
306
  return [linter_message]
296
- visitor = TorchFunctionVisitor(component_function)
307
+ visitor = TorchFunctionVisitor(component_function, validators)
297
308
  visitor.visit(module)
298
309
  linter_errors = visitor.linter_errors
299
310
  if not visitor.visited_function:
torchx/specs/finder.py CHANGED
@@ -16,14 +16,15 @@ from dataclasses import dataclass
16
16
  from inspect import getmembers, isfunction
17
17
  from pathlib import Path
18
18
  from types import ModuleType
19
- from typing import Callable, Dict, Generator, List, Optional, Union
19
+ from typing import Any, Callable, Dict, Generator, List, Optional, Union
20
20
 
21
21
  from torchx.specs import AppDef
22
- from torchx.specs.file_linter import get_fn_docstring, validate
22
+ from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
23
23
  from torchx.util import entrypoints
24
24
  from torchx.util.io import read_conf_file
25
25
  from torchx.util.types import none_throws
26
26
 
27
+
27
28
  logger: logging.Logger = logging.getLogger(__name__)
28
29
 
29
30
 
@@ -53,13 +54,18 @@ class _Component:
53
54
  name: str
54
55
  description: str
55
56
  fn_name: str
56
- fn: Callable[..., AppDef]
57
+
58
+ # pyre-ignore[4] TODO temporary until PipelineDef is decoupled and can be exposed as type to OSS
59
+ fn: Callable[..., Any]
60
+
57
61
  validation_errors: List[str]
58
62
 
59
63
 
60
64
  class ComponentsFinder(abc.ABC):
61
65
  @abc.abstractmethod
62
- def find(self) -> List[_Component]:
66
+ def find(
67
+ self, validators: Optional[List[TorchxFunctionValidator]]
68
+ ) -> List[_Component]:
63
69
  """
64
70
  Retrieves a set of components. A component is defined as a python
65
71
  function that conforms to ``torchx.specs.file_linter`` linter.
@@ -203,10 +209,12 @@ class ModuleComponentsFinder(ComponentsFinder):
203
209
  else:
204
210
  yield self._try_import(module_info.name)
205
211
 
206
- def find(self) -> List[_Component]:
212
+ def find(
213
+ self, validators: Optional[List[TorchxFunctionValidator]]
214
+ ) -> List[_Component]:
207
215
  components = []
208
216
  for m in self._iter_modules_recursive(self.base_module):
209
- components += self._get_components_from_module(m)
217
+ components += self._get_components_from_module(m, validators)
210
218
  return components
211
219
 
212
220
  def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
@@ -221,7 +229,9 @@ class ModuleComponentsFinder(ComponentsFinder):
221
229
  else:
222
230
  return module
223
231
 
224
- def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
232
+ def _get_components_from_module(
233
+ self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
234
+ ) -> List[_Component]:
225
235
  functions = getmembers(module, isfunction)
226
236
  component_defs = []
227
237
 
@@ -230,7 +240,7 @@ class ModuleComponentsFinder(ComponentsFinder):
230
240
  module_path = os.path.abspath(module_path)
231
241
  rel_module_name = module_relname(module, relative_to=self.base_module)
232
242
  for function_name, function in functions:
233
- linter_errors = validate(module_path, function_name)
243
+ linter_errors = validate(module_path, function_name, validators)
234
244
  component_desc, _ = get_fn_docstring(function)
235
245
 
236
246
  # remove empty string to deal with group=""
@@ -255,13 +265,20 @@ class CustomComponentsFinder(ComponentsFinder):
255
265
  self._filepath = filepath
256
266
  self._function_name = function_name
257
267
 
258
- def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
259
- linter_errors = validate(path, function_name)
268
+ def _get_validation_errors(
269
+ self,
270
+ path: str,
271
+ function_name: str,
272
+ validators: Optional[List[TorchxFunctionValidator]],
273
+ ) -> List[str]:
274
+ linter_errors = validate(path, function_name, validators)
260
275
  return [linter_error.description for linter_error in linter_errors]
261
276
 
262
- def find(self) -> List[_Component]:
277
+ def find(
278
+ self, validators: Optional[List[TorchxFunctionValidator]]
279
+ ) -> List[_Component]:
263
280
  validation_errors = self._get_validation_errors(
264
- self._filepath, self._function_name
281
+ self._filepath, self._function_name, validators
265
282
  )
266
283
 
267
284
  file_source = read_conf_file(self._filepath)
@@ -284,7 +301,9 @@ class CustomComponentsFinder(ComponentsFinder):
284
301
  ]
285
302
 
286
303
 
287
- def _load_custom_components() -> List[_Component]:
304
+ def _load_custom_components(
305
+ validators: Optional[List[TorchxFunctionValidator]],
306
+ ) -> List[_Component]:
288
307
  component_modules = {
289
308
  name: load_fn()
290
309
  for name, load_fn in
@@ -303,11 +322,13 @@ def _load_custom_components() -> List[_Component]:
303
322
  # _0 = torchx.components.dist
304
323
  # _1 = torchx.components.utils
305
324
  group = "" if group.startswith("_") else group
306
- components += ModuleComponentsFinder(module, group).find()
325
+ components += ModuleComponentsFinder(module, group).find(validators)
307
326
  return components
308
327
 
309
328
 
310
- def _load_components() -> Dict[str, _Component]:
329
+ def _load_components(
330
+ validators: Optional[List[TorchxFunctionValidator]],
331
+ ) -> Dict[str, _Component]:
311
332
  """
312
333
  Loads either the custom component defs from the entrypoint ``[torchx.components]``
313
334
  or the default builtins from ``torchx.components`` module.
@@ -318,19 +339,21 @@ def _load_components() -> Dict[str, _Component]:
318
339
 
319
340
  """
320
341
 
321
- components = _load_custom_components()
342
+ components = _load_custom_components(validators)
322
343
  if not components:
323
- components = ModuleComponentsFinder("torchx.components", "").find()
344
+ components = ModuleComponentsFinder("torchx.components", "").find(validators)
324
345
  return {c.name: c for c in components}
325
346
 
326
347
 
327
348
  _components: Optional[Dict[str, _Component]] = None
328
349
 
329
350
 
330
- def _find_components() -> Dict[str, _Component]:
351
+ def _find_components(
352
+ validators: Optional[List[TorchxFunctionValidator]],
353
+ ) -> Dict[str, _Component]:
331
354
  global _components
332
355
  if not _components:
333
- _components = _load_components()
356
+ _components = _load_components(validators)
334
357
  return none_throws(_components)
335
358
 
336
359
 
@@ -338,17 +361,21 @@ def _is_custom_component(component_name: str) -> bool:
338
361
  return ":" in component_name
339
362
 
340
363
 
341
- def _find_custom_components(name: str) -> Dict[str, _Component]:
364
+ def _find_custom_components(
365
+ name: str, validators: Optional[List[TorchxFunctionValidator]]
366
+ ) -> Dict[str, _Component]:
342
367
  if ":" not in name:
343
368
  raise ValueError(
344
369
  f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
345
370
  )
346
371
  filepath, component_name = name.split(":")
347
- components = CustomComponentsFinder(filepath, component_name).find()
372
+ components = CustomComponentsFinder(filepath, component_name).find(validators)
348
373
  return {component.name: component for component in components}
349
374
 
350
375
 
351
- def get_components() -> Dict[str, _Component]:
376
+ def get_components(
377
+ validators: Optional[List[TorchxFunctionValidator]] = None,
378
+ ) -> Dict[str, _Component]:
352
379
  """
353
380
  Returns all custom components registered via ``[torchx.components]`` entrypoints
354
381
  OR builtin components that ship with TorchX (but not both).
@@ -395,13 +422,15 @@ def get_components() -> Dict[str, _Component]:
395
422
  """
396
423
 
397
424
  valid_components: Dict[str, _Component] = {}
398
- for component_name, component in _find_components().items():
425
+ for component_name, component in _find_components(validators).items():
399
426
  if len(component.validation_errors) == 0:
400
427
  valid_components[component_name] = component
401
428
  return valid_components
402
429
 
403
430
 
404
- def get_component(name: str) -> _Component:
431
+ def get_component(
432
+ name: str, validators: Optional[List[TorchxFunctionValidator]] = None
433
+ ) -> _Component:
405
434
  """
406
435
  Retrieves components by the provided name.
407
436
 
@@ -409,9 +438,9 @@ def get_component(name: str) -> _Component:
409
438
  Component or None if no component with ``name`` exists
410
439
  """
411
440
  if _is_custom_component(name):
412
- components = _find_custom_components(name)
441
+ components = _find_custom_components(name, validators)
413
442
  else:
414
- components = _find_components()
443
+ components = _find_components(validators)
415
444
  if name not in components:
416
445
  raise ComponentNotFoundException(
417
446
  f"Component `{name}` not found. Please make sure it is one of the "
@@ -428,7 +457,9 @@ def get_component(name: str) -> _Component:
428
457
  return component
429
458
 
430
459
 
431
- def get_builtin_source(name: str) -> str:
460
+ def get_builtin_source(
461
+ name: str, validators: Optional[List[TorchxFunctionValidator]] = None
462
+ ) -> str:
432
463
  """
433
464
  Returns a string of the the builtin component's function source code
434
465
  with all the import statements. Intended to be used to make a copy
@@ -446,7 +477,7 @@ def get_builtin_source(name: str) -> str:
446
477
  are optimized and formatting adheres to your organization's standards.
447
478
  """
448
479
 
449
- component = get_component(name)
480
+ component = get_component(name, validators)
450
481
  fn = component.fn
451
482
  fn_name = component.name.split(".")[-1]
452
483
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: torchx-nightly
3
- Version: 2025.3.19
3
+ Version: 2025.3.21
4
4
  Summary: TorchX SDK and Components
5
5
  Home-page: https://github.com/pytorch/torchx
6
6
  Author: TorchX Devs
@@ -84,9 +84,9 @@ torchx/schedulers/ray/ray_common.py,sha256=pyNYFvTKVwdjDAeCBNbPwAWwVNmlLOJWExfn9
84
84
  torchx/schedulers/ray/ray_driver.py,sha256=RdaCLfth16ky-5PDVOWRe_RuheWJu9xufWux2F9T7iw,12302
85
85
  torchx/specs/__init__.py,sha256=T8xUCz7iVE6OsUL5P4Pzy2B8ZY_YinCVDwUer5Q-XPc,6179
86
86
  torchx/specs/api.py,sha256=ISdnfGH1R0-1Ah23UsPOZQT3WH63dIMj9CxA3QRoP0g,38046
87
- torchx/specs/builders.py,sha256=QDcQrnCO4bdSaiP0216XbCgTsnLutO_1_FW5jDiEIWI,9939
88
- torchx/specs/file_linter.py,sha256=IeiomB1BgHUlT-ZsvGxar3llY63NOupfLBrOrD_---A,11860
89
- torchx/specs/finder.py,sha256=MnwxG_UC4a-3X2wQ37ANEQR6D1TvriCLyuVYBh_-wuI,16249
87
+ torchx/specs/builders.py,sha256=rejt-X68zxTE_8_MWeoQ30cOj6pJNe4DS9YRhjrSDuA,10127
88
+ torchx/specs/file_linter.py,sha256=OOyDnfg9m_9BiyhXc3T0fhE_-ykKCJ32gxCQGbfT-yE,12157
89
+ torchx/specs/finder.py,sha256=N8EJcxYmG8ZBreB5FgvufGq-CnNOmvVYUpAdflkv7K8,17312
90
90
  torchx/specs/named_resources_aws.py,sha256=ISjHtifRJqB8u7PeAMiyLyO_S0WCaZiK-CFF3qe6JDU,11415
91
91
  torchx/specs/named_resources_generic.py,sha256=Sg4tAdqiiWDrDz2Lj_pnfsjzGIXKTou73wPseh6j55w,2646
92
92
  torchx/specs/test/components/__init__.py,sha256=J8qjUOysmcMAek2KFN13mViOXZxTYc5vCrF02t3VuFU,223
@@ -115,9 +115,9 @@ torchx/workspace/__init__.py,sha256=FqN8AN4VhR1C_SBY10MggQvNZmyanbbuPuE-JCjkyUY,
115
115
  torchx/workspace/api.py,sha256=PtDkGTC5lX03pRoYpuMz2KCmM1ZOycRP1UknqvNb97Y,6341
116
116
  torchx/workspace/dir_workspace.py,sha256=npNW_IjUZm_yS5r-8hrRkH46ndDd9a_eApT64m1S1T4,2268
117
117
  torchx/workspace/docker_workspace.py,sha256=PFu2KQNVC-0p2aKJ-W_BKA9ZOmXdCY2ABEkCExp3udQ,10269
118
- torchx_nightly-2025.3.19.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
119
- torchx_nightly-2025.3.19.dist-info/METADATA,sha256=qZoFDHWo4OWbDAk9_xXnMY8uX7wYT_kkV0FLO64aWuc,6169
120
- torchx_nightly-2025.3.19.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
121
- torchx_nightly-2025.3.19.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
122
- torchx_nightly-2025.3.19.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
123
- torchx_nightly-2025.3.19.dist-info/RECORD,,
118
+ torchx_nightly-2025.3.21.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
119
+ torchx_nightly-2025.3.21.dist-info/METADATA,sha256=H5AYA_AcfLs_WzWHkwaIRhWWz6Js7YJ2dMWjfB5S03o,6169
120
+ torchx_nightly-2025.3.21.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
121
+ torchx_nightly-2025.3.21.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
122
+ torchx_nightly-2025.3.21.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
123
+ torchx_nightly-2025.3.21.dist-info/RECORD,,