torchx-nightly 2025.8.5__py3-none-any.whl → 2026.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/cli/main.py +2 -0
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +102 -81
- torchx/runner/config.py +3 -1
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +66 -25
- torchx/schedulers/aws_batch_scheduler.py +47 -6
- torchx/schedulers/aws_sagemaker_scheduler.py +1 -1
- torchx/schedulers/docker_scheduler.py +4 -3
- torchx/schedulers/ids.py +27 -23
- torchx/schedulers/kubernetes_mcad_scheduler.py +1 -4
- torchx/schedulers/kubernetes_scheduler.py +355 -36
- torchx/schedulers/local_scheduler.py +2 -1
- torchx/schedulers/lsf_scheduler.py +1 -1
- torchx/schedulers/slurm_scheduler.py +102 -27
- torchx/specs/__init__.py +40 -9
- torchx/specs/api.py +222 -12
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/specs/overlays.py +106 -0
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/METADATA +35 -49
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/RECORD +46 -56
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2026.1.11.dist-info}/top_level.txt +0 -0
torchx/specs/api.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
1
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
2
|
# All rights reserved.
|
|
4
3
|
#
|
|
@@ -12,11 +11,15 @@ import copy
|
|
|
12
11
|
import inspect
|
|
13
12
|
import json
|
|
14
13
|
import logging as logger
|
|
14
|
+
import os
|
|
15
|
+
import pathlib
|
|
15
16
|
import re
|
|
17
|
+
import shutil
|
|
16
18
|
import typing
|
|
19
|
+
import warnings
|
|
17
20
|
from dataclasses import asdict, dataclass, field
|
|
18
21
|
from datetime import datetime
|
|
19
|
-
from enum import Enum
|
|
22
|
+
from enum import Enum, IntEnum
|
|
20
23
|
from json import JSONDecodeError
|
|
21
24
|
from string import Template
|
|
22
25
|
from typing import (
|
|
@@ -67,6 +70,32 @@ YELLOW_BOLD = "\033[1;33m"
|
|
|
67
70
|
RESET = "\033[0m"
|
|
68
71
|
|
|
69
72
|
|
|
73
|
+
def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
    """
    Path to the "dot-directory" for torchx.
    Defaults to ``~/.torchx`` and is overridable via the ``TORCHX_HOME`` environment variable.
    An unset or empty ``TORCHX_HOME`` falls back to the default.

    The returned directory (including any ``subdir_paths``) is created if it
    does not exist yet.

    Args:
        subdir_paths: optional path components appended to the torchx home directory.

    Returns:
        The resulting directory as a :py:class:`pathlib.Path`.

    Usage:

    .. doctest::

        from pathlib import Path
        from torchx.specs import TORCHX_HOME

        # holds when the ``TORCHX_HOME`` environment variable is not set
        assert TORCHX_HOME() == Path.home() / ".torchx"
        assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
    """

    # Use `or` instead of a getenv default: an empty TORCHX_HOME would otherwise
    # become Path("") (the current working directory). This also avoids resolving
    # the user's home directory when an override is provided.
    root = os.environ.get("TORCHX_HOME") or str(pathlib.Path.home() / ".torchx")

    torchx_home = pathlib.Path(root) / os.path.sep.join(subdir_paths)
    torchx_home.mkdir(parents=True, exist_ok=True)

    return torchx_home
|
|
97
|
+
|
|
98
|
+
|
|
70
99
|
# ========================================
|
|
71
100
|
# ==== Distributed AppDef API =======
|
|
72
101
|
# ========================================
|
|
@@ -83,6 +112,8 @@ class Resource:
|
|
|
83
112
|
memMB: MB of ram
|
|
84
113
|
capabilities: additional hardware specs (interpreted by scheduler)
|
|
85
114
|
devices: a list of named devices with their quantities
|
|
115
|
+
tags: metadata tags for the resource (not interpreted by schedulers)
|
|
116
|
+
used to add non-functional information about resources (e.g. whether it is an alias of another resource)
|
|
86
117
|
|
|
87
118
|
Note: you should prefer to use named_resources instead of specifying the raw
|
|
88
119
|
resource requirement directly.
|
|
@@ -93,6 +124,7 @@ class Resource:
|
|
|
93
124
|
memMB: int
|
|
94
125
|
capabilities: Dict[str, Any] = field(default_factory=dict)
|
|
95
126
|
devices: Dict[str, int] = field(default_factory=dict)
|
|
127
|
+
tags: Dict[str, object] = field(default_factory=dict)
|
|
96
128
|
|
|
97
129
|
@staticmethod
|
|
98
130
|
def copy(original: "Resource", **capabilities: Any) -> "Resource":
|
|
@@ -101,6 +133,7 @@ class Resource:
|
|
|
101
133
|
are present in the original resource and as parameter, the one from parameter
|
|
102
134
|
will be used.
|
|
103
135
|
"""
|
|
136
|
+
|
|
104
137
|
res_capabilities = dict(original.capabilities)
|
|
105
138
|
res_capabilities.update(capabilities)
|
|
106
139
|
return Resource(
|
|
@@ -220,7 +253,9 @@ class macros:
|
|
|
220
253
|
current_dict[k] = self.substitute(v)
|
|
221
254
|
elif isinstance(v, list):
|
|
222
255
|
for i in range(len(v)):
|
|
223
|
-
if isinstance(v[i],
|
|
256
|
+
if isinstance(v[i], dict):
|
|
257
|
+
stack.append(v[i])
|
|
258
|
+
elif isinstance(v[i], str):
|
|
224
259
|
v[i] = self.substitute(v[i])
|
|
225
260
|
return d
|
|
226
261
|
|
|
@@ -319,6 +354,121 @@ class DeviceMount:
|
|
|
319
354
|
permissions: str = "rwm"
|
|
320
355
|
|
|
321
356
|
|
|
357
|
+
@dataclass
|
|
358
|
+
class Workspace:
|
|
359
|
+
"""
|
|
360
|
+
Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built
|
|
361
|
+
into an (usually ephemeral) image. This effectively mirrors the local code changes
|
|
362
|
+
at job submission time.
|
|
363
|
+
|
|
364
|
+
For example:
|
|
365
|
+
|
|
366
|
+
1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**``
|
|
367
|
+
2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**``
|
|
368
|
+
|
|
369
|
+
The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between
|
|
370
|
+
different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`.
|
|
371
|
+
Check the scheduler documentation for details on which workspace it supports.
|
|
372
|
+
|
|
373
|
+
Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory.
|
|
374
|
+
Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``).
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
Attributes:
|
|
378
|
+
projects: mapping of local project to the sub-dir in the remote workspace dir.
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
projects: dict[str, str]
|
|
382
|
+
|
|
383
|
+
def __bool__(self) -> bool:
|
|
384
|
+
"""False if no projects mapping. Lets us use workspace object in an if-statement"""
|
|
385
|
+
return bool(self.projects)
|
|
386
|
+
|
|
387
|
+
def __eq__(self, other: object) -> bool:
|
|
388
|
+
if not isinstance(other, Workspace):
|
|
389
|
+
return False
|
|
390
|
+
return self.projects == other.projects
|
|
391
|
+
|
|
392
|
+
def __hash__(self) -> int:
|
|
393
|
+
# makes it possible to use Workspace as the key in the workspace build cache
|
|
394
|
+
# see WorkspaceMixin.caching_build_workspace_and_update_role
|
|
395
|
+
return hash(frozenset(self.projects.items()))
|
|
396
|
+
|
|
397
|
+
def is_unmapped_single_project(self) -> bool:
|
|
398
|
+
"""
|
|
399
|
+
Returns ``True`` if this workspace only has 1 project
|
|
400
|
+
and its target mapping is an empty string.
|
|
401
|
+
"""
|
|
402
|
+
return len(self.projects) == 1 and not next(iter(self.projects.values()))
|
|
403
|
+
|
|
404
|
+
def merge_into(self, outdir: str | pathlib.Path) -> None:
|
|
405
|
+
"""
|
|
406
|
+
Copies each project dir of this workspace into the specified ``outdir``.
|
|
407
|
+
Each project dir is copied into ``{outdir}/{target}`` where ``target`` is
|
|
408
|
+
the target mapping of the project dir.
|
|
409
|
+
|
|
410
|
+
For example:
|
|
411
|
+
|
|
412
|
+
.. code-block:: python
|
|
413
|
+
from os.path import expanduser
|
|
414
|
+
|
|
415
|
+
workspace = Workspace(
|
|
416
|
+
projects={
|
|
417
|
+
expanduser("~/workspace/torch"): "torch",
|
|
418
|
+
expanduser("~/workspace/my_project": "")
|
|
419
|
+
}
|
|
420
|
+
)
|
|
421
|
+
workspace.merge_into(expanduser("~/tmp"))
|
|
422
|
+
|
|
423
|
+
Copies:
|
|
424
|
+
|
|
425
|
+
* ``~/workspace/torch/**`` into ``~/tmp/torch/**``
|
|
426
|
+
* ``~/workspace/my_project/**`` into ``~/tmp/**``
|
|
427
|
+
|
|
428
|
+
"""
|
|
429
|
+
|
|
430
|
+
for src, dst in self.projects.items():
|
|
431
|
+
dst_path = pathlib.Path(outdir) / dst
|
|
432
|
+
if pathlib.Path(src).is_file():
|
|
433
|
+
shutil.copy2(src, dst_path)
|
|
434
|
+
else: # src is dir
|
|
435
|
+
shutil.copytree(src, dst_path, dirs_exist_ok=True)
|
|
436
|
+
|
|
437
|
+
@staticmethod
|
|
438
|
+
def from_str(workspace: str | None) -> "Workspace":
|
|
439
|
+
import yaml
|
|
440
|
+
|
|
441
|
+
if not workspace:
|
|
442
|
+
return Workspace({})
|
|
443
|
+
|
|
444
|
+
projects = yaml.safe_load(workspace)
|
|
445
|
+
if isinstance(projects, str): # single project workspace
|
|
446
|
+
projects = {projects: ""}
|
|
447
|
+
else: # multi-project workspace
|
|
448
|
+
# Replace None mappings with "" (empty string)
|
|
449
|
+
projects = {k: ("" if v is None else v) for k, v in projects.items()}
|
|
450
|
+
|
|
451
|
+
return Workspace(projects)
|
|
452
|
+
|
|
453
|
+
def __str__(self) -> str:
|
|
454
|
+
"""
|
|
455
|
+
Returns a string representation of the Workspace by concatenating
|
|
456
|
+
the project mappings using ';' as a delimiter and ':' between key and value.
|
|
457
|
+
If the single-project workspace with no target mapping, then simply
|
|
458
|
+
returns the src (local project dir)
|
|
459
|
+
|
|
460
|
+
NOTE: meant to be used for logging purposes not serde.
|
|
461
|
+
Therefore not symmetric with :py:func:`Workspace.from_str`.
|
|
462
|
+
|
|
463
|
+
"""
|
|
464
|
+
if self.is_unmapped_single_project():
|
|
465
|
+
return next(iter(self.projects))
|
|
466
|
+
else:
|
|
467
|
+
return ";".join(
|
|
468
|
+
k if not v else f"{k}:{v}" for k, v in self.projects.items()
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
|
|
322
472
|
@dataclass
|
|
323
473
|
class Role:
|
|
324
474
|
"""
|
|
@@ -371,12 +521,15 @@ class Role:
|
|
|
371
521
|
metadata: Free form information that is associated with the role, for example
|
|
372
522
|
scheduler specific data. The key should follow the pattern: ``$scheduler.$key``
|
|
373
523
|
mounts: a list of mounts on the machine
|
|
524
|
+
workspace: local project directories to be mirrored on the remote job.
|
|
525
|
+
NOTE: The workspace argument provided to the :py:class:`~torchx.runner.api.Runner` APIs
|
|
526
|
+
only takes effect on ``appdef.role[0]`` and overrides this attribute.
|
|
527
|
+
|
|
374
528
|
"""
|
|
375
529
|
|
|
376
530
|
name: str
|
|
377
531
|
image: str
|
|
378
532
|
min_replicas: Optional[int] = None
|
|
379
|
-
base_image: Optional[str] = None # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
|
|
380
533
|
entrypoint: str = MISSING
|
|
381
534
|
args: List[str] = field(default_factory=list)
|
|
382
535
|
env: Dict[str, str] = field(default_factory=dict)
|
|
@@ -386,9 +539,10 @@ class Role:
|
|
|
386
539
|
resource: Resource = field(default_factory=_null_resource)
|
|
387
540
|
port_map: Dict[str, int] = field(default_factory=dict)
|
|
388
541
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
389
|
-
mounts: List[
|
|
390
|
-
|
|
391
|
-
|
|
542
|
+
mounts: List[BindMount | VolumeMount | DeviceMount] = field(default_factory=list)
|
|
543
|
+
workspace: Workspace | None = None
|
|
544
|
+
|
|
545
|
+
# DEPRECATED DO NOT SET, WILL BE REMOVED SOON
|
|
392
546
|
overrides: Dict[str, Any] = field(default_factory=dict)
|
|
393
547
|
|
|
394
548
|
# pyre-ignore
|
|
@@ -788,6 +942,8 @@ class runopt:
|
|
|
788
942
|
opt_type: Type[CfgVal]
|
|
789
943
|
is_required: bool
|
|
790
944
|
help: str
|
|
945
|
+
aliases: list[str] | None = None
|
|
946
|
+
deprecated_aliases: list[str] | None = None
|
|
791
947
|
|
|
792
948
|
@property
|
|
793
949
|
def is_type_list_of_str(self) -> bool:
|
|
@@ -823,7 +979,7 @@ class runopt:
|
|
|
823
979
|
|
|
824
980
|
NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
|
|
825
981
|
at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
|
|
826
|
-
primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855
|
|
982
|
+
primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/meta-pytorch/torchx/pull/855
|
|
827
983
|
"""
|
|
828
984
|
|
|
829
985
|
if self.opt_type is None:
|
|
@@ -879,6 +1035,7 @@ class runopts:
|
|
|
879
1035
|
|
|
880
1036
|
def __init__(self) -> None:
|
|
881
1037
|
self._opts: Dict[str, runopt] = {}
|
|
1038
|
+
self._alias_to_key: dict[str, str] = {}
|
|
882
1039
|
|
|
883
1040
|
def __iter__(self) -> Iterator[Tuple[str, runopt]]:
|
|
884
1041
|
return self._opts.items().__iter__()
|
|
@@ -906,9 +1063,16 @@ class runopts:
|
|
|
906
1063
|
|
|
907
1064
|
def get(self, name: str) -> Optional[runopt]:
|
|
908
1065
|
"""
|
|
909
|
-
Returns option if any was registered, or None otherwise
|
|
1066
|
+
Returns option if any was registered, or None otherwise.
|
|
1067
|
+
First searches for the option by ``name``, then falls-back to matching ``name`` with any
|
|
1068
|
+
registered aliases.
|
|
1069
|
+
|
|
910
1070
|
"""
|
|
911
|
-
|
|
1071
|
+
if name in self._opts:
|
|
1072
|
+
return self._opts[name]
|
|
1073
|
+
if name in self._alias_to_key:
|
|
1074
|
+
return self._opts[self._alias_to_key[name]]
|
|
1075
|
+
return None
|
|
912
1076
|
|
|
913
1077
|
def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
|
|
914
1078
|
"""
|
|
@@ -923,6 +1087,36 @@ class runopts:
|
|
|
923
1087
|
|
|
924
1088
|
for cfg_key, runopt in self._opts.items():
|
|
925
1089
|
val = resolved_cfg.get(cfg_key)
|
|
1090
|
+
resolved_name = None
|
|
1091
|
+
aliases = runopt.aliases or []
|
|
1092
|
+
deprecated_aliases = runopt.deprecated_aliases or []
|
|
1093
|
+
if val is None:
|
|
1094
|
+
for alias in aliases:
|
|
1095
|
+
val = resolved_cfg.get(alias)
|
|
1096
|
+
if alias in cfg or val is not None:
|
|
1097
|
+
resolved_name = alias
|
|
1098
|
+
break
|
|
1099
|
+
for alias in deprecated_aliases:
|
|
1100
|
+
val = resolved_cfg.get(alias)
|
|
1101
|
+
if val is not None:
|
|
1102
|
+
resolved_name = alias
|
|
1103
|
+
use_instead = self._alias_to_key.get(alias)
|
|
1104
|
+
warnings.warn(
|
|
1105
|
+
f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
|
|
1106
|
+
UserWarning,
|
|
1107
|
+
stacklevel=2,
|
|
1108
|
+
)
|
|
1109
|
+
break
|
|
1110
|
+
else:
|
|
1111
|
+
resolved_name = cfg_key
|
|
1112
|
+
for alias in aliases:
|
|
1113
|
+
duplicate_val = resolved_cfg.get(alias)
|
|
1114
|
+
if alias in cfg or duplicate_val is not None:
|
|
1115
|
+
raise InvalidRunConfigException(
|
|
1116
|
+
f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
|
|
1117
|
+
resolved_name,
|
|
1118
|
+
cfg,
|
|
1119
|
+
)
|
|
926
1120
|
|
|
927
1121
|
# check required opt
|
|
928
1122
|
if runopt.is_required and val is None:
|
|
@@ -942,7 +1136,7 @@ class runopts:
|
|
|
942
1136
|
)
|
|
943
1137
|
|
|
944
1138
|
# not required and not set, set to default
|
|
945
|
-
if val is None:
|
|
1139
|
+
if val is None and resolved_name is None:
|
|
946
1140
|
resolved_cfg[cfg_key] = runopt.default
|
|
947
1141
|
return resolved_cfg
|
|
948
1142
|
|
|
@@ -1042,12 +1236,16 @@ class runopts:
|
|
|
1042
1236
|
help: str,
|
|
1043
1237
|
default: CfgVal = None,
|
|
1044
1238
|
required: bool = False,
|
|
1239
|
+
aliases: Optional[list[str]] = None,
|
|
1240
|
+
deprecated_aliases: Optional[list[str]] = None,
|
|
1045
1241
|
) -> None:
|
|
1046
1242
|
"""
|
|
1047
1243
|
Adds the ``config`` option with the given help string and ``default``
|
|
1048
1244
|
value (if any). If the ``default`` is not specified then this option
|
|
1049
1245
|
is a required option.
|
|
1050
1246
|
"""
|
|
1247
|
+
aliases = aliases or []
|
|
1248
|
+
deprecated_aliases = deprecated_aliases or []
|
|
1051
1249
|
if required and default is not None:
|
|
1052
1250
|
raise ValueError(
|
|
1053
1251
|
f"Required option: {cfg_key} must not specify default value. Given: {default}"
|
|
@@ -1059,7 +1257,19 @@ class runopts:
|
|
|
1059
1257
|
f" Given: {default} ({type(default).__name__})"
|
|
1060
1258
|
)
|
|
1061
1259
|
|
|
1062
|
-
|
|
1260
|
+
opt = runopt(
|
|
1261
|
+
default,
|
|
1262
|
+
type_,
|
|
1263
|
+
required,
|
|
1264
|
+
help,
|
|
1265
|
+
list(set(aliases)),
|
|
1266
|
+
list(set(deprecated_aliases)),
|
|
1267
|
+
)
|
|
1268
|
+
for alias in aliases:
|
|
1269
|
+
self._alias_to_key[alias] = cfg_key
|
|
1270
|
+
for deprecated_alias in deprecated_aliases:
|
|
1271
|
+
self._alias_to_key[deprecated_alias] = cfg_key
|
|
1272
|
+
self._opts[cfg_key] = opt
|
|
1063
1273
|
|
|
1064
1274
|
def update(self, other: "runopts") -> None:
|
|
1065
1275
|
self._opts.update(other._opts)
|
torchx/specs/builders.py
CHANGED
|
@@ -4,13 +4,13 @@
|
|
|
4
4
|
# This source code is licensed under the BSD-style license found in the
|
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
|
6
6
|
|
|
7
|
-
# pyre-
|
|
7
|
+
# pyre-unsafe
|
|
8
8
|
|
|
9
9
|
import argparse
|
|
10
10
|
import inspect
|
|
11
11
|
import os
|
|
12
12
|
from argparse import Namespace
|
|
13
|
-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|
13
|
+
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
|
|
14
14
|
|
|
15
15
|
from torchx.specs.api import BindMount, MountType, VolumeMount
|
|
16
16
|
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
|
|
@@ -19,6 +19,14 @@ from torchx.util.types import decode, decode_optional, get_argparse_param_type,
|
|
|
19
19
|
from .api import AppDef, DeviceMount
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
class ComponentArgs(NamedTuple):
    """Parsed component function arguments, grouped by how they are passed to the function."""

    # positional and positional-or-keyword arguments, keyed by parameter name
    positional_args: dict[str, Any]
    # values for the variable positional parameter (``*args``)
    var_args: list[str]
    # keyword-only arguments, keyed by parameter name
    kwargs: dict[str, Any]
|
|
28
|
+
|
|
29
|
+
|
|
22
30
|
def _create_args_parser(
|
|
23
31
|
cmpnt_fn: Callable[..., AppDef],
|
|
24
32
|
cmpnt_defaults: Optional[Dict[str, str]] = None,
|
|
@@ -31,7 +39,7 @@ def _create_args_parser(
|
|
|
31
39
|
|
|
32
40
|
|
|
33
41
|
def _create_args_parser_from_parameters(
|
|
34
|
-
cmpnt_fn: Callable[...,
|
|
42
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
35
43
|
parameters: Mapping[str, inspect.Parameter],
|
|
36
44
|
cmpnt_defaults: Optional[Dict[str, str]] = None,
|
|
37
45
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -112,7 +120,7 @@ def _merge_config_values_with_args(
|
|
|
112
120
|
|
|
113
121
|
|
|
114
122
|
def parse_args(
|
|
115
|
-
cmpnt_fn: Callable[...,
|
|
123
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
116
124
|
cmpnt_args: List[str],
|
|
117
125
|
cmpnt_defaults: Optional[Dict[str, Any]] = None,
|
|
118
126
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -140,8 +148,97 @@ def parse_args(
|
|
|
140
148
|
return parsed_args
|
|
141
149
|
|
|
142
150
|
|
|
151
|
+
def component_args_from_str(
    cmpnt_fn: Callable[..., AppDef],
    cmpnt_args: list[str],
    cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
    config: Optional[Dict[str, Any]] = None,
) -> ComponentArgs:
    """
    Parses and decodes command-line arguments for a component function.

    This function takes a component function and its arguments, parses them using argparse,
    and decodes the arguments into their expected types based on the function's signature.
    It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.

    Args:
        cmpnt_fn: The component function whose arguments are to be parsed and decoded.
        cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
        cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
        config: Optional dictionary containing additional configuration values.

    Returns:
        ComponentArgs representing the input args to a component function containing:
            - positional_args: Dictionary of positional and positional-or-keyword arguments.
            - var_args: List of variable positional arguments (*args).
            - kwargs: Dictionary of keyword-only arguments.

    Raises:
        TypeError: if ``cmpnt_fn`` declares a ``**kwargs`` parameter (not supported).

    Usage:

    .. doctest::

        from torchx.specs.api import AppDef
        from torchx.specs.builders import component_args_from_str

        def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
            return AppDef(name="example")

        # Supports space separated arguments
        args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
        parsed_args = component_args_from_str(example_component_fn, args)

        assert parsed_args.positional_args == {"foo": "fooval"}
        assert parsed_args.var_args == ["arg1", "arg2"]
        assert parsed_args.kwargs == {"bar": "barval"}

        # Supports '=' separated arguments
        args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
        parsed_args = component_args_from_str(example_component_fn, args)

        assert parsed_args.positional_args == {"foo": "fooval"}
        assert parsed_args.var_args == ["arg1", "arg2"]
        assert parsed_args.kwargs == {"bar": "barval"}

    """
    parsed_args: Namespace = parse_args(
        cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
    )

    positional_args = {}
    var_args = []
    kwargs = {}

    parameters = inspect.signature(cmpnt_fn).parameters
    for param_name, parameter in parameters.items():
        # Reject '**kwargs' before touching the parsed namespace so that the
        # caller gets this explanatory error rather than whatever the attribute
        # lookup or decode step would raise for the unsupported parameter.
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            raise TypeError(
                f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
                f"type to a dict or explicitly declare the params"
            )

        arg_value = getattr(parsed_args, param_name)
        parameter_type = parameter.annotation
        parameter_type = decode_optional(parameter_type)
        # '*args' values are passed through as parsed; everything else is
        # decoded unless it is already an instance of exactly the annotated type
        if (
            parameter_type != arg_value.__class__
            and parameter.kind != inspect.Parameter.VAR_POSITIONAL
        ):
            arg_value = decode(arg_value, parameter_type)

        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            var_args = arg_value
        elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
            kwargs[param_name] = arg_value
        else:
            # POSITIONAL or POSITIONAL_OR_KEYWORD
            positional_args[param_name] = arg_value

    # drop the leading "--" delimiter from the collected '*args', if present
    if len(var_args) > 0 and var_args[0] == "--":
        var_args = var_args[1:]

    return ComponentArgs(positional_args, var_args, kwargs)
|
|
238
|
+
|
|
239
|
+
|
|
143
240
|
def materialize_appdef(
|
|
144
|
-
cmpnt_fn: Callable[...,
|
|
241
|
+
cmpnt_fn: Callable[..., AppDef],
|
|
145
242
|
cmpnt_args: List[str],
|
|
146
243
|
cmpnt_defaults: Optional[Dict[str, Any]] = None,
|
|
147
244
|
config: Optional[Dict[str, Any]] = None,
|
|
@@ -174,30 +271,14 @@ def materialize_appdef(
|
|
|
174
271
|
An application spec
|
|
175
272
|
"""
|
|
176
273
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
for param_name, parameter in parameters.items():
|
|
185
|
-
arg_value = getattr(parsed_args, param_name)
|
|
186
|
-
parameter_type = parameter.annotation
|
|
187
|
-
parameter_type = decode_optional(parameter_type)
|
|
188
|
-
arg_value = decode(arg_value, parameter_type)
|
|
189
|
-
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
190
|
-
var_arg = arg_value
|
|
191
|
-
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
|
|
192
|
-
kwargs[param_name] = arg_value
|
|
193
|
-
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
|
194
|
-
raise TypeError("**kwargs are not supported for component definitions")
|
|
195
|
-
else:
|
|
196
|
-
function_args.append(arg_value)
|
|
197
|
-
if len(var_arg) > 0 and var_arg[0] == "--":
|
|
198
|
-
var_arg = var_arg[1:]
|
|
274
|
+
component_args: ComponentArgs = component_args_from_str(
|
|
275
|
+
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
|
|
276
|
+
)
|
|
277
|
+
positional_arg_values = list(component_args.positional_args.values())
|
|
278
|
+
appdef = cmpnt_fn(
|
|
279
|
+
*positional_arg_values, *component_args.var_args, **component_args.kwargs
|
|
280
|
+
)
|
|
199
281
|
|
|
200
|
-
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
|
|
201
282
|
if not isinstance(appdef, AppDef):
|
|
202
283
|
raise TypeError(
|
|
203
284
|
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
|