torchx-nightly 2025.8.5__py3-none-any.whl → 2026.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  2. torchx/cli/cmd_delete.py +30 -0
  3. torchx/cli/cmd_list.py +1 -2
  4. torchx/cli/cmd_run.py +202 -28
  5. torchx/cli/cmd_tracker.py +1 -1
  6. torchx/cli/main.py +2 -0
  7. torchx/components/__init__.py +1 -8
  8. torchx/components/dist.py +9 -3
  9. torchx/components/integration_tests/component_provider.py +2 -2
  10. torchx/components/utils.py +1 -1
  11. torchx/distributed/__init__.py +1 -1
  12. torchx/runner/api.py +102 -81
  13. torchx/runner/config.py +3 -1
  14. torchx/runner/events/__init__.py +20 -10
  15. torchx/runner/events/api.py +1 -1
  16. torchx/schedulers/__init__.py +7 -10
  17. torchx/schedulers/api.py +66 -25
  18. torchx/schedulers/aws_batch_scheduler.py +47 -6
  19. torchx/schedulers/aws_sagemaker_scheduler.py +1 -1
  20. torchx/schedulers/docker_scheduler.py +4 -3
  21. torchx/schedulers/ids.py +27 -23
  22. torchx/schedulers/kubernetes_mcad_scheduler.py +1 -4
  23. torchx/schedulers/kubernetes_scheduler.py +355 -36
  24. torchx/schedulers/local_scheduler.py +2 -1
  25. torchx/schedulers/lsf_scheduler.py +1 -1
  26. torchx/schedulers/slurm_scheduler.py +102 -27
  27. torchx/specs/__init__.py +40 -9
  28. torchx/specs/api.py +222 -12
  29. torchx/specs/builders.py +109 -28
  30. torchx/specs/file_linter.py +117 -53
  31. torchx/specs/finder.py +25 -37
  32. torchx/specs/named_resources_aws.py +13 -2
  33. torchx/specs/overlays.py +106 -0
  34. torchx/tracker/__init__.py +2 -2
  35. torchx/tracker/api.py +1 -1
  36. torchx/util/entrypoints.py +1 -6
  37. torchx/util/strings.py +1 -1
  38. torchx/util/types.py +12 -1
  39. torchx/version.py +2 -2
  40. torchx/workspace/api.py +102 -5
  41. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/METADATA +35 -49
  42. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/RECORD +46 -56
  43. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/WHEEL +1 -1
  44. torchx/examples/pipelines/__init__.py +0 -0
  45. torchx/examples/pipelines/kfp/__init__.py +0 -0
  46. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
  47. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
  48. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
  49. torchx/pipelines/kfp/__init__.py +0 -30
  50. torchx/pipelines/kfp/adapter.py +0 -274
  51. torchx/pipelines/kfp/version.py +0 -19
  52. torchx/schedulers/gcp_batch_scheduler.py +0 -497
  53. torchx/schedulers/ray/ray_common.py +0 -22
  54. torchx/schedulers/ray/ray_driver.py +0 -307
  55. torchx/schedulers/ray_scheduler.py +0 -454
  56. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/entry_points.txt +0 -0
  57. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info/licenses}/LICENSE +0 -0
  58. {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/top_level.txt +0 -0
torchx/specs/api.py CHANGED
@@ -1,4 +1,3 @@
1
- #!/usr/bin/env python3
2
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
3
2
  # All rights reserved.
4
3
  #
@@ -12,11 +11,15 @@ import copy
12
11
  import inspect
13
12
  import json
14
13
  import logging as logger
14
+ import os
15
+ import pathlib
15
16
  import re
17
+ import shutil
16
18
  import typing
19
+ import warnings
17
20
  from dataclasses import asdict, dataclass, field
18
21
  from datetime import datetime
19
- from enum import Enum
22
+ from enum import Enum, IntEnum
20
23
  from json import JSONDecodeError
21
24
  from string import Template
22
25
  from typing import (
@@ -67,6 +70,32 @@ YELLOW_BOLD = "\033[1;33m"
67
70
  RESET = "\033[0m"
68
71
 
69
72
 
73
+ def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
74
+ """
75
+ Path to the "dot-directory" for torchx.
76
+ Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.
77
+
78
+ Usage:
79
+
80
+ .. doctest::
81
+
82
+ from pathlib import Path
83
+ from torchx.specs import TORCHX_HOME
84
+
85
+ assert TORCHX_HOME() == Path.home() / ".torchx"
86
+ assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
87
+
88
+ """
89
+
90
+ default_dir = str(pathlib.Path.home() / ".torchx")
91
+ torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))
92
+
93
+ torchx_home = torchx_home / os.path.sep.join(subdir_paths)
94
+ torchx_home.mkdir(parents=True, exist_ok=True)
95
+
96
+ return torchx_home
97
+
98
+
70
99
  # ========================================
71
100
  # ==== Distributed AppDef API =======
72
101
  # ========================================
@@ -83,6 +112,8 @@ class Resource:
83
112
  memMB: MB of ram
84
113
  capabilities: additional hardware specs (interpreted by scheduler)
85
114
  devices: a list of named devices with their quantities
115
+ tags: metadata tags for the resource (not interpreted by schedulers)
116
+ used to add non-functional information about resources (e.g. whether it is an alias of another resource)
86
117
 
87
118
  Note: you should prefer to use named_resources instead of specifying the raw
88
119
  resource requirement directly.
@@ -93,6 +124,7 @@ class Resource:
93
124
  memMB: int
94
125
  capabilities: Dict[str, Any] = field(default_factory=dict)
95
126
  devices: Dict[str, int] = field(default_factory=dict)
127
+ tags: Dict[str, object] = field(default_factory=dict)
96
128
 
97
129
  @staticmethod
98
130
  def copy(original: "Resource", **capabilities: Any) -> "Resource":
@@ -101,6 +133,7 @@ class Resource:
101
133
  are present in the original resource and as parameter, the one from parameter
102
134
  will be used.
103
135
  """
136
+
104
137
  res_capabilities = dict(original.capabilities)
105
138
  res_capabilities.update(capabilities)
106
139
  return Resource(
@@ -220,7 +253,9 @@ class macros:
220
253
  current_dict[k] = self.substitute(v)
221
254
  elif isinstance(v, list):
222
255
  for i in range(len(v)):
223
- if isinstance(v[i], str):
256
+ if isinstance(v[i], dict):
257
+ stack.append(v[i])
258
+ elif isinstance(v[i], str):
224
259
  v[i] = self.substitute(v[i])
225
260
  return d
226
261
 
@@ -319,6 +354,121 @@ class DeviceMount:
319
354
  permissions: str = "rwm"
320
355
 
321
356
 
357
+ @dataclass
358
+ class Workspace:
359
+ """
360
+ Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built
361
+ into an (usually ephemeral) image. This effectively mirrors the local code changes
362
+ at job submission time.
363
+
364
+ For example:
365
+
366
+ 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**``
367
+ 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**``
368
+
369
+ The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between
370
+ different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`.
371
+ Check the scheduler documentation for details on which workspace it supports.
372
+
373
+ Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory.
374
+ Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``).
375
+
376
+
377
+ Attributes:
378
+ projects: mapping of local project to the sub-dir in the remote workspace dir.
379
+ """
380
+
381
+ projects: dict[str, str]
382
+
383
+ def __bool__(self) -> bool:
384
+ """False if no projects mapping. Lets us use workspace object in an if-statement"""
385
+ return bool(self.projects)
386
+
387
+ def __eq__(self, other: object) -> bool:
388
+ if not isinstance(other, Workspace):
389
+ return False
390
+ return self.projects == other.projects
391
+
392
+ def __hash__(self) -> int:
393
+ # makes it possible to use Workspace as the key in the workspace build cache
394
+ # see WorkspaceMixin.caching_build_workspace_and_update_role
395
+ return hash(frozenset(self.projects.items()))
396
+
397
+ def is_unmapped_single_project(self) -> bool:
398
+ """
399
+ Returns ``True`` if this workspace only has 1 project
400
+ and its target mapping is an empty string.
401
+ """
402
+ return len(self.projects) == 1 and not next(iter(self.projects.values()))
403
+
404
+ def merge_into(self, outdir: str | pathlib.Path) -> None:
405
+ """
406
+ Copies each project dir of this workspace into the specified ``outdir``.
407
+ Each project dir is copied into ``{outdir}/{target}`` where ``target`` is
408
+ the target mapping of the project dir.
409
+
410
+ For example:
411
+
412
+ .. code-block:: python
413
+ from os.path import expanduser
414
+
415
+ workspace = Workspace(
416
+ projects={
417
+ expanduser("~/workspace/torch"): "torch",
418
+ expanduser("~/workspace/my_project"): ""
419
+ }
420
+ )
421
+ workspace.merge_into(expanduser("~/tmp"))
422
+
423
+ Copies:
424
+
425
+ * ``~/workspace/torch/**`` into ``~/tmp/torch/**``
426
+ * ``~/workspace/my_project/**`` into ``~/tmp/**``
427
+
428
+ """
429
+
430
+ for src, dst in self.projects.items():
431
+ dst_path = pathlib.Path(outdir) / dst
432
+ if pathlib.Path(src).is_file():
433
+ shutil.copy2(src, dst_path)
434
+ else: # src is dir
435
+ shutil.copytree(src, dst_path, dirs_exist_ok=True)
436
+
437
+ @staticmethod
438
+ def from_str(workspace: str | None) -> "Workspace":
439
+ import yaml
440
+
441
+ if not workspace:
442
+ return Workspace({})
443
+
444
+ projects = yaml.safe_load(workspace)
445
+ if isinstance(projects, str): # single project workspace
446
+ projects = {projects: ""}
447
+ else: # multi-project workspace
448
+ # Replace None mappings with "" (empty string)
449
+ projects = {k: ("" if v is None else v) for k, v in projects.items()}
450
+
451
+ return Workspace(projects)
452
+
453
+ def __str__(self) -> str:
454
+ """
455
+ Returns a string representation of the Workspace by concatenating
456
+ the project mappings using ';' as a delimiter and ':' between key and value.
457
+ For a single-project workspace with no target mapping, simply
458
+ returns the src (local project dir)
459
+
460
+ NOTE: meant to be used for logging purposes not serde.
461
+ Therefore not symmetric with :py:func:`Workspace.from_str`.
462
+
463
+ """
464
+ if self.is_unmapped_single_project():
465
+ return next(iter(self.projects))
466
+ else:
467
+ return ";".join(
468
+ k if not v else f"{k}:{v}" for k, v in self.projects.items()
469
+ )
470
+
471
+
322
472
  @dataclass
323
473
  class Role:
324
474
  """
@@ -371,12 +521,15 @@ class Role:
371
521
  metadata: Free form information that is associated with the role, for example
372
522
  scheduler specific data. The key should follow the pattern: ``$scheduler.$key``
373
523
  mounts: a list of mounts on the machine
524
+ workspace: local project directories to be mirrored on the remote job.
525
+ NOTE: The workspace argument provided to the :py:class:`~torchx.runner.api.Runner` APIs
526
+ only takes effect on ``appdef.role[0]`` and overrides this attribute.
527
+
374
528
  """
375
529
 
376
530
  name: str
377
531
  image: str
378
532
  min_replicas: Optional[int] = None
379
- base_image: Optional[str] = None # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
380
533
  entrypoint: str = MISSING
381
534
  args: List[str] = field(default_factory=list)
382
535
  env: Dict[str, str] = field(default_factory=dict)
@@ -386,9 +539,10 @@ class Role:
386
539
  resource: Resource = field(default_factory=_null_resource)
387
540
  port_map: Dict[str, int] = field(default_factory=dict)
388
541
  metadata: Dict[str, Any] = field(default_factory=dict)
389
- mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
390
- default_factory=list
391
- )
542
+ mounts: List[BindMount | VolumeMount | DeviceMount] = field(default_factory=list)
543
+ workspace: Workspace | None = None
544
+
545
+ # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
392
546
  overrides: Dict[str, Any] = field(default_factory=dict)
393
547
 
394
548
  # pyre-ignore
@@ -788,6 +942,8 @@ class runopt:
788
942
  opt_type: Type[CfgVal]
789
943
  is_required: bool
790
944
  help: str
945
+ aliases: list[str] | None = None
946
+ deprecated_aliases: list[str] | None = None
791
947
 
792
948
  @property
793
949
  def is_type_list_of_str(self) -> bool:
@@ -823,7 +979,7 @@ class runopt:
823
979
 
824
980
  NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
825
981
  at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
826
- primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855
982
+ primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/meta-pytorch/torchx/pull/855
827
983
  """
828
984
 
829
985
  if self.opt_type is None:
@@ -879,6 +1035,7 @@ class runopts:
879
1035
 
880
1036
  def __init__(self) -> None:
881
1037
  self._opts: Dict[str, runopt] = {}
1038
+ self._alias_to_key: dict[str, str] = {}
882
1039
 
883
1040
  def __iter__(self) -> Iterator[Tuple[str, runopt]]:
884
1041
  return self._opts.items().__iter__()
@@ -906,9 +1063,16 @@ class runopts:
906
1063
 
907
1064
  def get(self, name: str) -> Optional[runopt]:
908
1065
  """
909
- Returns option if any was registered, or None otherwise
1066
+ Returns option if any was registered, or None otherwise.
1067
+ First searches for the option by ``name``, then falls-back to matching ``name`` with any
1068
+ registered aliases.
1069
+
910
1070
  """
911
- return self._opts.get(name, None)
1071
+ if name in self._opts:
1072
+ return self._opts[name]
1073
+ if name in self._alias_to_key:
1074
+ return self._opts[self._alias_to_key[name]]
1075
+ return None
912
1076
 
913
1077
  def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
914
1078
  """
@@ -923,6 +1087,36 @@ class runopts:
923
1087
 
924
1088
  for cfg_key, runopt in self._opts.items():
925
1089
  val = resolved_cfg.get(cfg_key)
1090
+ resolved_name = None
1091
+ aliases = runopt.aliases or []
1092
+ deprecated_aliases = runopt.deprecated_aliases or []
1093
+ if val is None:
1094
+ for alias in aliases:
1095
+ val = resolved_cfg.get(alias)
1096
+ if alias in cfg or val is not None:
1097
+ resolved_name = alias
1098
+ break
1099
+ for alias in deprecated_aliases:
1100
+ val = resolved_cfg.get(alias)
1101
+ if val is not None:
1102
+ resolved_name = alias
1103
+ use_instead = self._alias_to_key.get(alias)
1104
+ warnings.warn(
1105
+ f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
1106
+ UserWarning,
1107
+ stacklevel=2,
1108
+ )
1109
+ break
1110
+ else:
1111
+ resolved_name = cfg_key
1112
+ for alias in aliases:
1113
+ duplicate_val = resolved_cfg.get(alias)
1114
+ if alias in cfg or duplicate_val is not None:
1115
+ raise InvalidRunConfigException(
1116
+ f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
1117
+ resolved_name,
1118
+ cfg,
1119
+ )
926
1120
 
927
1121
  # check required opt
928
1122
  if runopt.is_required and val is None:
@@ -942,7 +1136,7 @@ class runopts:
942
1136
  )
943
1137
 
944
1138
  # not required and not set, set to default
945
- if val is None:
1139
+ if val is None and resolved_name is None:
946
1140
  resolved_cfg[cfg_key] = runopt.default
947
1141
  return resolved_cfg
948
1142
 
@@ -1042,12 +1236,16 @@ class runopts:
1042
1236
  help: str,
1043
1237
  default: CfgVal = None,
1044
1238
  required: bool = False,
1239
+ aliases: Optional[list[str]] = None,
1240
+ deprecated_aliases: Optional[list[str]] = None,
1045
1241
  ) -> None:
1046
1242
  """
1047
1243
  Adds the ``config`` option with the given help string and ``default``
1048
1244
  value (if any). If the ``default`` is not specified then this option
1049
1245
  is a required option.
1050
1246
  """
1247
+ aliases = aliases or []
1248
+ deprecated_aliases = deprecated_aliases or []
1051
1249
  if required and default is not None:
1052
1250
  raise ValueError(
1053
1251
  f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1059,7 +1257,19 @@ class runopts:
1059
1257
  f" Given: {default} ({type(default).__name__})"
1060
1258
  )
1061
1259
 
1062
- self._opts[cfg_key] = runopt(default, type_, required, help)
1260
+ opt = runopt(
1261
+ default,
1262
+ type_,
1263
+ required,
1264
+ help,
1265
+ list(set(aliases)),
1266
+ list(set(deprecated_aliases)),
1267
+ )
1268
+ for alias in aliases:
1269
+ self._alias_to_key[alias] = cfg_key
1270
+ for deprecated_alias in deprecated_aliases:
1271
+ self._alias_to_key[deprecated_alias] = cfg_key
1272
+ self._opts[cfg_key] = opt
1063
1273
 
1064
1274
  def update(self, other: "runopts") -> None:
1065
1275
  self._opts.update(other._opts)
torchx/specs/builders.py CHANGED
@@ -4,13 +4,13 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
- # pyre-strict
7
+ # pyre-unsafe
8
8
 
9
9
  import argparse
10
10
  import inspect
11
11
  import os
12
12
  from argparse import Namespace
13
- from typing import Any, Callable, Dict, List, Mapping, Optional, Union
13
+ from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
14
14
 
15
15
  from torchx.specs.api import BindMount, MountType, VolumeMount
16
16
  from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
@@ -19,6 +19,14 @@ from torchx.util.types import decode, decode_optional, get_argparse_param_type,
19
19
  from .api import AppDef, DeviceMount
20
20
 
21
21
 
22
+ class ComponentArgs(NamedTuple):
23
+ """Parsed component function arguments"""
24
+
25
+ positional_args: dict[str, Any]
26
+ var_args: list[str]
27
+ kwargs: dict[str, Any]
28
+
29
+
22
30
  def _create_args_parser(
23
31
  cmpnt_fn: Callable[..., AppDef],
24
32
  cmpnt_defaults: Optional[Dict[str, str]] = None,
@@ -31,7 +39,7 @@ def _create_args_parser(
31
39
 
32
40
 
33
41
  def _create_args_parser_from_parameters(
34
- cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
42
+ cmpnt_fn: Callable[..., AppDef],
35
43
  parameters: Mapping[str, inspect.Parameter],
36
44
  cmpnt_defaults: Optional[Dict[str, str]] = None,
37
45
  config: Optional[Dict[str, Any]] = None,
@@ -112,7 +120,7 @@ def _merge_config_values_with_args(
112
120
 
113
121
 
114
122
  def parse_args(
115
- cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
123
+ cmpnt_fn: Callable[..., AppDef],
116
124
  cmpnt_args: List[str],
117
125
  cmpnt_defaults: Optional[Dict[str, Any]] = None,
118
126
  config: Optional[Dict[str, Any]] = None,
@@ -140,8 +148,97 @@ def parse_args(
140
148
  return parsed_args
141
149
 
142
150
 
151
+ def component_args_from_str(
152
+ cmpnt_fn: Callable[..., AppDef],
153
+ cmpnt_args: list[str],
154
+ cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
155
+ config: Optional[Dict[str, Any]] = None,
156
+ ) -> ComponentArgs:
157
+ """
158
+ Parses and decodes command-line arguments for a component function.
159
+
160
+ This function takes a component function and its arguments, parses them using argparse,
161
+ and decodes the arguments into their expected types based on the function's signature.
162
+ It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.
163
+
164
+ Args:
165
+ cmpnt_fn: The component function whose arguments are to be parsed and decoded.
166
+ cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
167
+ cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
168
+ config: Optional dictionary containing additional configuration values.
169
+
170
+ Returns:
171
+ ComponentArgs representing the input args to a component function containing:
172
+ - positional_args: Dictionary of positional and positional-or-keyword arguments.
173
+ - var_args: List of variable positional arguments (*args).
174
+ - kwargs: Dictionary of keyword-only arguments.
175
+
176
+ Usage:
177
+
178
+ .. doctest::
179
+ from torchx.specs.api import AppDef
180
+ from torchx.specs.builders import component_args_from_str
181
+
182
+ def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
183
+ return AppDef(name="example")
184
+
185
+ # Supports space separated arguments
186
+ args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
187
+ parsed_args = component_args_from_str(example_component_fn, args)
188
+
189
+ assert parsed_args.positional_args == {"foo": "fooval"}
190
+ assert parsed_args.var_args == ["arg1", "arg2"]
191
+ assert parsed_args.kwargs == {"bar": "barval"}
192
+
193
+ # Supports '=' separated arguments
194
+ args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
195
+ parsed_args = component_args_from_str(example_component_fn, args)
196
+
197
+ assert parsed_args.positional_args == {"foo": "fooval"}
198
+ assert parsed_args.var_args == ["arg1", "arg2"]
199
+ assert parsed_args.kwargs == {"bar": "barval"}
200
+
201
+
202
+ """
203
+ parsed_args: Namespace = parse_args(
204
+ cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
205
+ )
206
+
207
+ positional_args = {}
208
+ var_args = []
209
+ kwargs = {}
210
+
211
+ parameters = inspect.signature(cmpnt_fn).parameters
212
+ for param_name, parameter in parameters.items():
213
+ arg_value = getattr(parsed_args, param_name)
214
+ parameter_type = parameter.annotation
215
+ parameter_type = decode_optional(parameter_type)
216
+ if (
217
+ parameter_type != arg_value.__class__
218
+ and parameter.kind != inspect.Parameter.VAR_POSITIONAL
219
+ ):
220
+ arg_value = decode(arg_value, parameter_type)
221
+ if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
222
+ var_args = arg_value
223
+ elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
224
+ kwargs[param_name] = arg_value
225
+ elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
226
+ raise TypeError(
227
+ f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
228
+ f"type to a dict or explicitly declare the params"
229
+ )
230
+ else:
231
+ # POSITIONAL or POSITIONAL_OR_KEYWORD
232
+ positional_args[param_name] = arg_value
233
+
234
+ if len(var_args) > 0 and var_args[0] == "--":
235
+ var_args = var_args[1:]
236
+
237
+ return ComponentArgs(positional_args, var_args, kwargs)
238
+
239
+
143
240
  def materialize_appdef(
144
- cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
241
+ cmpnt_fn: Callable[..., AppDef],
145
242
  cmpnt_args: List[str],
146
243
  cmpnt_defaults: Optional[Dict[str, Any]] = None,
147
244
  config: Optional[Dict[str, Any]] = None,
@@ -174,30 +271,14 @@ def materialize_appdef(
174
271
  An application spec
175
272
  """
176
273
 
177
- function_args = []
178
- var_arg = []
179
- kwargs = {}
180
-
181
- parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
182
-
183
- parameters = inspect.signature(cmpnt_fn).parameters
184
- for param_name, parameter in parameters.items():
185
- arg_value = getattr(parsed_args, param_name)
186
- parameter_type = parameter.annotation
187
- parameter_type = decode_optional(parameter_type)
188
- arg_value = decode(arg_value, parameter_type)
189
- if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
190
- var_arg = arg_value
191
- elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
192
- kwargs[param_name] = arg_value
193
- elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
194
- raise TypeError("**kwargs are not supported for component definitions")
195
- else:
196
- function_args.append(arg_value)
197
- if len(var_arg) > 0 and var_arg[0] == "--":
198
- var_arg = var_arg[1:]
274
+ component_args: ComponentArgs = component_args_from_str(
275
+ cmpnt_fn, cmpnt_args, cmpnt_defaults, config
276
+ )
277
+ positional_arg_values = list(component_args.positional_args.values())
278
+ appdef = cmpnt_fn(
279
+ *positional_arg_values, *component_args.var_args, **component_args.kwargs
280
+ )
199
281
 
200
- appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
201
282
  if not isinstance(appdef, AppDef):
202
283
  raise TypeError(
203
284
  f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"