torchx-nightly 2025.8.5__py3-none-any.whl → 2025.11.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  2. torchx/cli/cmd_list.py +1 -2
  3. torchx/cli/cmd_run.py +202 -28
  4. torchx/cli/cmd_tracker.py +1 -1
  5. torchx/components/__init__.py +1 -8
  6. torchx/components/dist.py +9 -3
  7. torchx/components/integration_tests/component_provider.py +2 -2
  8. torchx/components/utils.py +1 -1
  9. torchx/distributed/__init__.py +1 -1
  10. torchx/runner/api.py +92 -81
  11. torchx/runner/config.py +3 -1
  12. torchx/runner/events/__init__.py +20 -10
  13. torchx/runner/events/api.py +1 -1
  14. torchx/schedulers/__init__.py +7 -10
  15. torchx/schedulers/api.py +20 -15
  16. torchx/schedulers/aws_batch_scheduler.py +45 -2
  17. torchx/schedulers/docker_scheduler.py +3 -0
  18. torchx/schedulers/kubernetes_scheduler.py +200 -17
  19. torchx/schedulers/local_scheduler.py +1 -0
  20. torchx/schedulers/slurm_scheduler.py +93 -24
  21. torchx/specs/__init__.py +23 -6
  22. torchx/specs/api.py +219 -11
  23. torchx/specs/builders.py +109 -28
  24. torchx/specs/file_linter.py +117 -53
  25. torchx/specs/finder.py +25 -37
  26. torchx/specs/named_resources_aws.py +13 -2
  27. torchx/tracker/__init__.py +2 -2
  28. torchx/tracker/api.py +1 -1
  29. torchx/util/entrypoints.py +1 -6
  30. torchx/util/strings.py +1 -1
  31. torchx/util/types.py +12 -1
  32. torchx/version.py +2 -2
  33. torchx/workspace/api.py +102 -5
  34. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
  35. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
  36. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
  37. torchx/examples/pipelines/__init__.py +0 -0
  38. torchx/examples/pipelines/kfp/__init__.py +0 -0
  39. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
  40. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
  41. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
  42. torchx/pipelines/kfp/__init__.py +0 -30
  43. torchx/pipelines/kfp/adapter.py +0 -274
  44. torchx/pipelines/kfp/version.py +0 -19
  45. torchx/schedulers/gcp_batch_scheduler.py +0 -497
  46. torchx/schedulers/ray/ray_common.py +0 -22
  47. torchx/schedulers/ray/ray_driver.py +0 -307
  48. torchx/schedulers/ray_scheduler.py +0 -454
  49. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
  50. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
  51. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/specs/file_linter.py CHANGED
@@ -11,8 +11,9 @@ import abc
11
11
  import argparse
12
12
  import ast
13
13
  import inspect
14
+ import sys
14
15
  from dataclasses import dataclass
15
- from typing import Callable, cast, Dict, List, Optional, Tuple
16
+ from typing import Callable, Dict, List, Optional, Tuple
16
17
 
17
18
  from docstring_parser import parse
18
19
  from torchx.util.io import read_conf_file
@@ -74,7 +75,7 @@ def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
74
75
  if the description
75
76
  """
76
77
  default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
77
- to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
78
+ to your component (see: https://meta-pytorch.org/torchx/latest/component_best_practices.html)"""
78
79
  args_description = _get_default_arguments_descriptions(fn)
79
80
  func_description = inspect.getdoc(fn)
80
81
  if not func_description:
@@ -98,7 +99,7 @@ class LinterMessage:
98
99
  severity: str = "error"
99
100
 
100
101
 
101
- class TorchxFunctionValidator(abc.ABC):
102
+ class ComponentFunctionValidator(abc.ABC):
102
103
  @abc.abstractmethod
103
104
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
104
105
  """
@@ -116,7 +117,55 @@ class TorchxFunctionValidator(abc.ABC):
116
117
  )
117
118
 
118
119
 
119
- class TorchxFunctionArgsValidator(TorchxFunctionValidator):
120
+ def OK() -> list[LinterMessage]:
121
+ return [] # empty linter error means validation passes
122
+
123
+
124
+ def is_primitive(arg: ast.expr) -> bool:
125
+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126
+ return isinstance(arg, ast.Name)
127
+
128
+
129
+ def get_generic_type(arg: ast.expr) -> ast.expr:
130
+ # returns the slice expr of a subscripted type
131
+ # `arg` must be an instance of ast.Subscript (caller checks)
132
+ # in this validator's context, this is the generic type of a container type
133
+ # e.g. for Optional[str] returns the expr for str
134
+
135
+ assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
136
+
137
+ if isinstance(arg.slice, ast.Index): # python>=3.10
138
+ return arg.slice.value
139
+ else: # python-3.9
140
+ return arg.slice
141
+
142
+
143
+ def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
144
+ """
145
+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146
+ is not an ``Optional``. Handles both:
147
+ 1. ``typing.Optional[T]`` (python<3.10)
148
+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149
+ """
150
+ # case 1: 'a: Optional[T]'
151
+ if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
152
+ return get_generic_type(arg)
153
+
154
+ # case 2: 'a: T | None' or 'a: None | T'
155
+ if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
156
+ if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
157
+ if isinstance(arg.right, ast.Constant) and arg.right.value is None:
158
+ return arg.left
159
+ if isinstance(arg.left, ast.Constant) and arg.left.value is None:
160
+ return arg.right
161
+
162
+ # case 3: is not optional
163
+ return None
164
+
165
+
166
+ class ArgTypeValidator(ComponentFunctionValidator):
167
+ """Validates component function's argument types."""
168
+
120
169
  def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
121
170
  linter_errors = []
122
171
  for arg_def in app_specs_func_def.args.args:
@@ -133,53 +182,68 @@ class TorchxFunctionArgsValidator(TorchxFunctionValidator):
133
182
  return linter_errors
134
183
 
135
184
  def _validate_arg_def(
136
- self, function_name: str, arg_def: ast.arg
185
+ self, function_name: str, arg: ast.arg
137
186
  ) -> List[LinterMessage]:
138
- if not arg_def.annotation:
139
- return [
140
- self._gen_linter_message(
141
- f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
142
- )
143
- ]
144
- if isinstance(arg_def.annotation, ast.Name):
187
+ arg_type = arg.annotation # type hint
188
+
189
+ def ok() -> list[LinterMessage]:
190
+ # return value when validation passes (e.g. no linter errors)
145
191
  return []
146
- complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
147
- if complex_type_def.value.id == "Optional":
148
- # ast module in python3.9 does not have ast.Index wrapper
149
- if isinstance(complex_type_def.slice, ast.Index):
150
- complex_type_def = complex_type_def.slice.value
151
- else:
152
- complex_type_def = complex_type_def.slice
153
- # Check if type is Optional[primitive_type]
154
- if isinstance(complex_type_def, ast.Name):
155
- return []
156
- # Check if type is Union[Dict,List]
157
- type_name = complex_type_def.value.id
158
- if type_name != "Dict" and type_name != "List":
159
- desc = (
160
- f"`{function_name}` allows only Dict, List as complex types."
161
- f"Argument `{arg_def.arg}` has: {type_name}"
162
- )
163
- return [self._gen_linter_message(desc, arg_def.lineno)]
164
- linter_errors = []
165
- # ast module in python3.9 does not have objects wrapped in ast.Index
166
- if isinstance(complex_type_def.slice, ast.Index):
167
- sub_type = complex_type_def.slice.value
192
+
193
+ def err(reason: str) -> list[LinterMessage]:
194
+ msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
195
+ return [self._gen_linter_message(msg, arg.lineno)]
196
+
197
+ if not arg_type:
198
+ return err("Missing type annotation")
199
+
200
+ # Case 1: optional
201
+ if T := get_optional_type(arg_type):
202
+ # NOTE: optional types can be primitives or any of the allowed container types
203
+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204
+ arg_type = T
205
+
206
+ # Case 2: int, float, str, bool
207
+ if is_primitive(arg_type):
208
+ return ok()
209
+ # Case 3: Containers (Dict, List, Tuple)
210
+ elif isinstance(arg_type, ast.Subscript):
211
+ container_type = arg_type.value.id
212
+
213
+ if container_type in ["Dict", "dict"]:
214
+ KV = get_generic_type(arg_type)
215
+
216
+ assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
217
+
218
+ K, V = KV.elts
219
+ if not is_primitive(K):
220
+ return err(f"Non-primitive key type {ast.unparse(K)!r}")
221
+ if not is_primitive(V):
222
+ return err(f"Non-primitive value type {ast.unparse(V)!r}")
223
+ return ok()
224
+ elif container_type in ["List", "list"]:
225
+ T = get_generic_type(arg_type)
226
+ if is_primitive(T):
227
+ return ok()
228
+ else:
229
+ return err(f"Non-primitive element type {ast.unparse(T)!r}")
230
+ elif container_type in ["Tuple", "tuple"]:
231
+ E_N = get_generic_type(arg_type)
232
+ assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
233
+
234
+ for e in E_N.elts:
235
+ if not is_primitive(e):
236
+ return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
237
+
238
+ return ok()
239
+
240
+ return err(f"Unsupported container type {container_type!r}")
168
241
  else:
169
- sub_type = complex_type_def.slice
170
- if type_name == "Dict":
171
- sub_type_tuple = cast(ast.Tuple, sub_type)
172
- for el in sub_type_tuple.elts:
173
- if not isinstance(el, ast.Name):
174
- desc = "Dict can only have primitive types"
175
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
176
- elif not isinstance(sub_type, ast.Name):
177
- desc = "List can only have primitive types"
178
- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
179
- return linter_errors
242
+ return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
180
243
 
181
244
 
182
- class TorchxReturnValidator(TorchxFunctionValidator):
245
+ class ReturnTypeValidator(ComponentFunctionValidator):
246
+ """Validates that component functions always return AppDef type"""
183
247
 
184
248
  def __init__(self, supported_return_type: str) -> None:
185
249
  super().__init__()
@@ -231,7 +295,7 @@ class TorchxReturnValidator(TorchxFunctionValidator):
231
295
  return linter_errors
232
296
 
233
297
 
234
- class TorchFunctionVisitor(ast.NodeVisitor):
298
+ class ComponentFnVisitor(ast.NodeVisitor):
235
299
  """
236
300
  Visitor that finds the component_function and runs registered validators on it.
237
301
  Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252
316
  def __init__(
253
317
  self,
254
318
  component_function_name: str,
255
- validators: Optional[List[TorchxFunctionValidator]],
319
+ validators: Optional[List[ComponentFunctionValidator]],
256
320
  ) -> None:
257
321
  if validators is None:
258
- self.validators: List[TorchxFunctionValidator] = [
259
- TorchxFunctionArgsValidator(),
260
- TorchxReturnValidator("AppDef"),
322
+ self.validators: List[ComponentFunctionValidator] = [
323
+ ArgTypeValidator(),
324
+ ReturnTypeValidator("AppDef"),
261
325
  ]
262
326
  else:
263
327
  self.validators = validators
@@ -279,7 +343,7 @@ class TorchFunctionVisitor(ast.NodeVisitor):
279
343
  def validate(
280
344
  path: str,
281
345
  component_function: str,
282
- validators: Optional[List[TorchxFunctionValidator]],
346
+ validators: Optional[List[ComponentFunctionValidator]] = None,
283
347
  ) -> List[LinterMessage]:
284
348
  """
285
349
  Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309
373
  severity="error",
310
374
  )
311
375
  return [linter_message]
312
- visitor = TorchFunctionVisitor(component_function, validators)
376
+ visitor = ComponentFnVisitor(component_function, validators)
313
377
  visitor.visit(module)
314
378
  linter_errors = visitor.linter_errors
315
379
  if not visitor.visited_function:
torchx/specs/finder.py CHANGED
@@ -17,9 +17,15 @@ from dataclasses import dataclass
17
17
  from inspect import getmembers, isfunction
18
18
  from pathlib import Path
19
19
  from types import ModuleType
20
- from typing import Any, Callable, Dict, Generator, List, Optional, Union
20
+ from typing import Callable, Dict, Generator, List, Optional, Union
21
21
 
22
- from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
22
+ from torchx.specs import AppDef
23
+
24
+ from torchx.specs.file_linter import (
25
+ ComponentFunctionValidator,
26
+ get_fn_docstring,
27
+ validate,
28
+ )
23
29
  from torchx.util import entrypoints
24
30
  from torchx.util.io import read_conf_file
25
31
  from torchx.util.types import none_throws
@@ -55,8 +61,7 @@ class _Component:
55
61
  description: str
56
62
  fn_name: str
57
63
 
58
- # pyre-ignore[4] TODO temporary until PipelineDef is decoupled and can be exposed as type to OSS
59
- fn: Callable[..., Any]
64
+ fn: Callable[..., AppDef]
60
65
 
61
66
  validation_errors: List[str]
62
67
 
@@ -64,7 +69,7 @@ class _Component:
64
69
  class ComponentsFinder(abc.ABC):
65
70
  @abc.abstractmethod
66
71
  def find(
67
- self, validators: Optional[List[TorchxFunctionValidator]]
72
+ self, validators: Optional[List[ComponentFunctionValidator]]
68
73
  ) -> List[_Component]:
69
74
  """
70
75
  Retrieves a set of components. A component is defined as a python
@@ -210,7 +215,7 @@ class ModuleComponentsFinder(ComponentsFinder):
210
215
  yield self._try_import(module_info.name)
211
216
 
212
217
  def find(
213
- self, validators: Optional[List[TorchxFunctionValidator]]
218
+ self, validators: Optional[List[ComponentFunctionValidator]]
214
219
  ) -> List[_Component]:
215
220
  components = []
216
221
  for m in self._iter_modules_recursive(self.base_module):
@@ -230,7 +235,7 @@ class ModuleComponentsFinder(ComponentsFinder):
230
235
  return module
231
236
 
232
237
  def _get_components_from_module(
233
- self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
238
+ self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
234
239
  ) -> List[_Component]:
235
240
  functions = getmembers(module, isfunction)
236
241
  component_defs = []
@@ -269,28 +274,17 @@ class CustomComponentsFinder(ComponentsFinder):
269
274
  self,
270
275
  path: str,
271
276
  function_name: str,
272
- validators: Optional[List[TorchxFunctionValidator]],
277
+ validators: Optional[List[ComponentFunctionValidator]],
273
278
  ) -> List[str]:
274
279
  linter_errors = validate(path, function_name, validators)
275
280
  return [linter_error.description for linter_error in linter_errors]
276
281
 
277
- def _get_path_to_function_decl(
278
- self, function: Callable[..., Any] # pyre-ignore[2]
279
- ) -> str:
280
- """
281
- Attempts to return the path to the file where the function is implemented.
282
- This can be different from the path where the function is looked up, for example if we have:
283
- my_component defined in some_file.py, imported in other_file.py
284
- and the component is invoked as other_file.py:my_component
285
- """
286
- path_to_function_decl = inspect.getabsfile(function)
287
- if path_to_function_decl is None or not os.path.isfile(path_to_function_decl):
288
- return self._filepath
289
- return path_to_function_decl
290
-
291
282
  def find(
292
- self, validators: Optional[List[TorchxFunctionValidator]]
283
+ self, validators: Optional[List[ComponentFunctionValidator]]
293
284
  ) -> List[_Component]:
285
+ validation_errors = self._get_validation_errors(
286
+ self._filepath, self._function_name, validators
287
+ )
294
288
 
295
289
  file_source = read_conf_file(self._filepath)
296
290
  namespace = copy.copy(globals())
@@ -303,12 +297,6 @@ class CustomComponentsFinder(ComponentsFinder):
303
297
  )
304
298
  app_fn = namespace[self._function_name]
305
299
  fn_desc, _ = get_fn_docstring(app_fn)
306
-
307
- func_path = self._get_path_to_function_decl(app_fn)
308
- validation_errors = self._get_validation_errors(
309
- func_path, self._function_name, validators
310
- )
311
-
312
300
  return [
313
301
  _Component(
314
302
  name=f"{self._filepath}:{self._function_name}",
@@ -321,7 +309,7 @@ class CustomComponentsFinder(ComponentsFinder):
321
309
 
322
310
 
323
311
  def _load_custom_components(
324
- validators: Optional[List[TorchxFunctionValidator]],
312
+ validators: Optional[List[ComponentFunctionValidator]],
325
313
  ) -> List[_Component]:
326
314
  component_modules = {
327
315
  name: load_fn()
@@ -346,7 +334,7 @@ def _load_custom_components(
346
334
 
347
335
 
348
336
  def _load_components(
349
- validators: Optional[List[TorchxFunctionValidator]],
337
+ validators: Optional[List[ComponentFunctionValidator]],
350
338
  ) -> Dict[str, _Component]:
351
339
  """
352
340
  Loads either the custom component defs from the entrypoint ``[torchx.components]``
@@ -368,7 +356,7 @@ _components: Optional[Dict[str, _Component]] = None
368
356
 
369
357
 
370
358
  def _find_components(
371
- validators: Optional[List[TorchxFunctionValidator]],
359
+ validators: Optional[List[ComponentFunctionValidator]],
372
360
  ) -> Dict[str, _Component]:
373
361
  global _components
374
362
  if not _components:
@@ -381,7 +369,7 @@ def _is_custom_component(component_name: str) -> bool:
381
369
 
382
370
 
383
371
  def _find_custom_components(
384
- name: str, validators: Optional[List[TorchxFunctionValidator]]
372
+ name: str, validators: Optional[List[ComponentFunctionValidator]]
385
373
  ) -> Dict[str, _Component]:
386
374
  if ":" not in name:
387
375
  raise ValueError(
@@ -393,7 +381,7 @@ def _find_custom_components(
393
381
 
394
382
 
395
383
  def get_components(
396
- validators: Optional[List[TorchxFunctionValidator]] = None,
384
+ validators: Optional[List[ComponentFunctionValidator]] = None,
397
385
  ) -> Dict[str, _Component]:
398
386
  """
399
387
  Returns all custom components registered via ``[torchx.components]`` entrypoints
@@ -448,7 +436,7 @@ def get_components(
448
436
 
449
437
 
450
438
  def get_component(
451
- name: str, validators: Optional[List[TorchxFunctionValidator]] = None
439
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
452
440
  ) -> _Component:
453
441
  """
454
442
  Retrieves components by the provided name.
@@ -464,7 +452,7 @@ def get_component(
464
452
  raise ComponentNotFoundException(
465
453
  f"Component `{name}` not found. Please make sure it is one of the "
466
454
  "builtins: `torchx builtins`. Or registered via `[torchx.components]` "
467
- "entry point (see: https://pytorch.org/torchx/latest/configure.html)"
455
+ "entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
468
456
  )
469
457
 
470
458
  component = components[name]
@@ -477,7 +465,7 @@ def get_component(
477
465
 
478
466
 
479
467
  def get_builtin_source(
480
- name: str, validators: Optional[List[TorchxFunctionValidator]] = None
468
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
481
469
  ) -> str:
482
470
  """
483
471
  Returns a string of the the builtin component's function source code
torchx/specs/named_resources_aws.py CHANGED
@@ -16,7 +16,7 @@ the equvalent resource in mem, cpu and gpu numbers.
16
16
 
17
17
  .. note::
18
18
  These resource definitions may change in future. It is expected for each user to
19
- manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
19
+ manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
20
20
  to set up named resources.
21
21
 
22
22
  Usage:
@@ -47,7 +47,7 @@ NEURON_DEVICE = "aws.amazon.com/neurondevice"
47
47
  MEM_TAX = 0.96
48
48
 
49
49
  # determines instance type for non-honogeneous CEs
50
- # see https://github.com/pytorch/torchx/issues/780
50
+ # see https://github.com/meta-pytorch/torchx/issues/780
51
51
  K8S_ITYPE = "node.kubernetes.io/instance-type"
52
52
  GiB: int = int(1024 * MEM_TAX)
53
53
 
@@ -120,6 +120,16 @@ def aws_p5_48xlarge() -> Resource:
120
120
  )
121
121
 
122
122
 
123
+ def aws_p5e_48xlarge() -> Resource:
124
+ return Resource(
125
+ cpu=192,
126
+ gpu=8,
127
+ memMB=2048 * GiB,
128
+ capabilities={K8S_ITYPE: "p5e.48xlarge"},
129
+ devices={EFA_DEVICE: 32},
130
+ )
131
+
132
+
123
133
  def aws_p5en_48xlarge() -> Resource:
124
134
  return Resource(
125
135
  cpu=192,
@@ -419,6 +429,7 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
419
429
  "aws_p4d.24xlarge": aws_p4d_24xlarge,
420
430
  "aws_p4de.24xlarge": aws_p4de_24xlarge,
421
431
  "aws_p5.48xlarge": aws_p5_48xlarge,
432
+ "aws_p5e.48xlarge": aws_p5e_48xlarge,
422
433
  "aws_p5en.48xlarge": aws_p5en_48xlarge,
423
434
  "aws_g4dn.xlarge": aws_g4dn_xlarge,
424
435
  "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
torchx/tracker/__init__.py CHANGED
@@ -32,7 +32,7 @@ implementation.
32
32
 
33
33
  Example usage
34
34
  -------------
35
- Sample `code <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
35
+ Sample `code <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
36
36
 
37
37
 
38
38
  Tracker Setup
@@ -111,7 +111,7 @@ Use :py:meth:`~torchx.tracker.app_run_from_env`:
111
111
  Reference :py:class:`~torchx.tracker.api.TrackerBase` implementation
112
112
  --------------------------------------------------------------------
113
113
  :py:class:`~torchx.tracker.backend.fsspec.FsspecTracker` provides reference implementation of a tracker backend.
114
- GitHub example `directory <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
114
+ GitHub example `directory <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
115
115
  configure and use it in user application.
116
116
 
117
117
 
torchx/tracker/api.py CHANGED
@@ -191,7 +191,7 @@ def build_trackers(
191
191
  factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
192
192
  if not factory:
193
193
  logger.warning(
194
- f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
194
+ f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
195
195
  )
196
196
  continue
197
197
  if config:
torchx/util/entrypoints.py CHANGED
@@ -69,9 +69,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
69
69
  return run
70
70
 
71
71
 
72
- def load_group(
73
- group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
74
- ):
72
+ def load_group(group: str, default: Optional[Dict[str, Any]] = None):
75
73
  """
76
74
  Loads all the entry points specified by ``group`` and returns
77
75
  the entry points as a map of ``name (str) -> deferred_load_fn``.
@@ -90,7 +88,6 @@ def load_group(
90
88
  1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
91
89
  1. ``load_group("food")`` -> ``None``
92
90
  1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
93
- 1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
94
91
 
95
92
 
96
93
  If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -115,8 +112,6 @@ def load_group(
115
112
  entrypoints = metadata.entry_points().get(group, ())
116
113
 
117
114
  if len(entrypoints) == 0:
118
- if skip_defaults:
119
- return None
120
115
  return default
121
116
 
122
117
  eps = {}
torchx/util/strings.py CHANGED
@@ -13,7 +13,7 @@ def normalize_str(data: str) -> str:
13
13
  """
14
14
  Invokes ``lower`` on thes string and removes all
15
15
  characters that do not satisfy ``[a-z0-9\\-]`` pattern.
16
- This method is mostly used to make sure kubernetes and gcp_batch scheduler gets
16
+ This method is mostly used to make sure kubernetes scheduler gets
17
17
  the job name that does not violate its restrictions.
18
18
  """
19
19
  if data.startswith("-"):
torchx/util/types.py CHANGED
@@ -8,6 +8,7 @@
8
8
 
9
9
  import inspect
10
10
  import re
11
+ from types import UnionType
11
12
  from typing import Any, Callable, Optional, Tuple, TypeVar, Union
12
13
 
13
14
 
@@ -234,10 +235,20 @@ def decode_optional(param_type: Any) -> Any:
234
235
  If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE
235
236
  Otherwise returns ``param_type``
236
237
  """
238
+
237
239
  if not hasattr(param_type, "__origin__"):
238
- return param_type
240
+ if isinstance(param_type, UnionType):
241
+ # handle BinOp style Optional (e.g. `T | None`)
242
+ if len(param_type.__args__) == 2 and param_type.__args__[1] is type(None):
243
+ return param_type.__args__[0]
244
+ else:
245
+ return param_type
246
+ else:
247
+ return param_type
248
+
239
249
  if param_type.__origin__ is not Union:
240
250
  return param_type
251
+
241
252
  args = param_type.__args__
242
253
  if len(args) == 2 and args[1] is type(None):
243
254
  return args[0]
torchx/version.py CHANGED
@@ -1,4 +1,3 @@
1
- #!/usr/bin/env python3
2
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
3
2
  # All rights reserved.
4
3
  #
@@ -7,6 +6,7 @@
7
6
 
8
7
  # pyre-strict
9
8
 
9
+ from torchx._version import BASE_VERSION
10
10
  from torchx.util.entrypoints import load
11
11
 
12
12
  # Follows PEP-0440 version scheme guidelines
@@ -18,7 +18,7 @@ from torchx.util.entrypoints import load
18
18
  # 0.1.0bN # Beta release
19
19
  # 0.1.0rcN # Release Candidate
20
20
  # 0.1.0 # Final release
21
- __version__ = "0.8.0dev0"
21
+ __version__: str = BASE_VERSION
22
22
 
23
23
 
24
24
  # Use the github container registry images corresponding to the current package