torchx-nightly 2025.8.5__py3-none-any.whl → 2025.11.12__py3-none-any.whl
This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the registry.
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/cli/cmd_list.py +1 -2
- torchx/cli/cmd_run.py +202 -28
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -8
- torchx/components/dist.py +9 -3
- torchx/components/integration_tests/component_provider.py +2 -2
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +92 -81
- torchx/runner/config.py +3 -1
- torchx/runner/events/__init__.py +20 -10
- torchx/runner/events/api.py +1 -1
- torchx/schedulers/__init__.py +7 -10
- torchx/schedulers/api.py +20 -15
- torchx/schedulers/aws_batch_scheduler.py +45 -2
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/local_scheduler.py +1 -0
- torchx/schedulers/slurm_scheduler.py +93 -24
- torchx/specs/__init__.py +23 -6
- torchx/specs/api.py +219 -11
- torchx/specs/builders.py +109 -28
- torchx/specs/file_linter.py +117 -53
- torchx/specs/finder.py +25 -37
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/util/strings.py +1 -1
- torchx/util/types.py +12 -1
- torchx/version.py +2 -2
- torchx/workspace/api.py +102 -5
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
- torchx/pipelines/kfp/__init__.py +0 -30
- torchx/pipelines/kfp/adapter.py +0 -274
- torchx/pipelines/kfp/version.py +0 -19
- torchx/schedulers/gcp_batch_scheduler.py +0 -497
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -454
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.8.5.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/{schedulers/ray/__init__.py → _version.py}
RENAMED
@@ -1,6 +1,8 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+BASE_VERSION = "0.8.0dev0"
torchx/cli/cmd_list.py
CHANGED
@@ -33,8 +33,7 @@ class CmdList(SubCommand):
             type=str,
             default=get_default_scheduler_name(),
             choices=list(scheduler_names),
-            help=f"Name of the scheduler to use. One of: [{','.join(scheduler_names)}]."
-            " For listing app handles for ray scheduler, RAY_ADDRESS env variable should be set.",
+            help=f"Name of the scheduler to use. One of: [{','.join(scheduler_names)}].",
         )

     def run(self, args: argparse.Namespace) -> None:
torchx/cli/cmd_run.py
CHANGED
@@ -7,16 +7,17 @@
 # pyre-strict

 import argparse
+import json
 import logging
 import os
 import sys
 import threading
 from collections import Counter
-from dataclasses import asdict
+from dataclasses import asdict, dataclass, field, fields, MISSING as DATACLASS_MISSING
 from itertools import groupby
 from pathlib import Path
 from pprint import pformat
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torchx.specs as specs
 from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
@@ -25,6 +26,7 @@ from torchx.cli.cmd_log import get_logs
 from torchx.runner import config, get_runner, Runner
 from torchx.runner.config import load_sections
 from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
+from torchx.specs import CfgVal, Workspace
 from torchx.specs.finder import (
     _Component,
     ComponentNotFoundException,
@@ -40,10 +42,81 @@ MISSING_COMPONENT_ERROR_MSG = (
     "missing component name, either provide it from the CLI or in .torchxconfig"
 )

+LOCAL_SCHEDULER_WARNING_MSG = (
+    "`local` scheduler is deprecated and will be"
+    " removed in the near future,"
+    " please use other variants of the local scheduler"
+    " (e.g. `local_cwd`)"
+)

 logger: logging.Logger = logging.getLogger(__name__)


+@dataclass
+class TorchXRunArgs:
+    component_name: str
+    scheduler: str
+    scheduler_args: Dict[str, Any]
+    scheduler_cfg: Dict[str, CfgVal] = field(default_factory=dict)
+    dryrun: bool = False
+    wait: bool = False
+    log: bool = False
+    workspace: str = ""
+    parent_run_id: Optional[str] = None
+    tee_logs: bool = False
+    component_args: Dict[str, Any] = field(default_factory=dict)
+    component_args_str: List[str] = field(default_factory=list)
+
+
+def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
+    all_fields = [f.name for f in fields(TorchXRunArgs)]
+    required_fields = {
+        f.name
+        for f in fields(TorchXRunArgs)
+        if f.default is DATACLASS_MISSING and f.default_factory is DATACLASS_MISSING
+    }
+    missing_fields = required_fields - json_data.keys()
+    if missing_fields:
+        raise ValueError(
+            f"The following required fields are missing: {', '.join(missing_fields)}"
+        )
+
+    # Fail if there are fields that aren't part of the run command
+    filtered_json_data = {k: v for k, v in json_data.items() if k in all_fields}
+    extra_fields = set(json_data.keys()) - set(all_fields)
+    if extra_fields:
+        raise ValueError(
+            f"The following fields are not part of the run command: {', '.join(extra_fields)}.",
+            "Please check your JSON and try launching again.",
+        )
+
+    torchx_args = TorchXRunArgs(**filtered_json_data)
+    if torchx_args.workspace == "":
+        torchx_args.workspace = f"{Path.cwd()}"
+    return torchx_args
+
+
+def torchx_run_args_from_argparse(
+    args: argparse.Namespace,
+    component_name: str,
+    component_args: List[str],
+    scheduler_cfg: Dict[str, CfgVal],
+) -> TorchXRunArgs:
+    return TorchXRunArgs(
+        component_name=component_name,
+        scheduler=args.scheduler,
+        scheduler_args={},
+        scheduler_cfg=scheduler_cfg,
+        dryrun=args.dryrun,
+        wait=args.wait,
+        log=args.log,
+        workspace=args.workspace,
+        parent_run_id=args.parent_run_id,
+        tee_logs=args.tee_logs,
+        component_args_str=component_args,
+    )
+
+
 def _parse_component_name_and_args(
     component_name_and_args: List[str],
     subparser: argparse.ArgumentParser,
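For context, a minimal sketch of a payload accepted by the new `torchx_run_args_from_json` — only the fields without defaults (`component_name`, `scheduler`, `scheduler_args`) are required, and the component and argument values here are hypothetical:

```python
from torchx.cli.cmd_run import torchx_run_args_from_json

# Hypothetical payload; keys mirror the TorchXRunArgs fields above.
payload = {
    "component_name": "utils.echo",
    "scheduler": "local_cwd",
    "scheduler_args": {},
    "component_args": {"msg": "hello"},
}

run_args = torchx_run_args_from_json(payload)
print(run_args.workspace)  # an empty workspace is backfilled with Path.cwd()
```

Unknown keys raise a `ValueError` rather than being silently dropped, so typos in automation payloads fail fast.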
@@ -133,6 +206,7 @@ class CmdBuiltins(SubCommand):
 class CmdRun(SubCommand):
     def __init__(self) -> None:
         self._subparser: Optional[argparse.ArgumentParser] = None
+        self._stdin_data_json: Optional[Dict[str, Any]] = None

     def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
         scheduler_names = get_scheduler_factories().keys()
@@ -176,7 +250,7 @@
         subparser.add_argument(
             "--workspace",
             "--buck-target",
-            default=f"
+            default=f"{Path.cwd()}",
             action=torchxconfig_run,
             help="local workspace to build/patch (buck-target of main binary if using buck)",
         )
@@ -193,36 +267,37 @@
             default=False,
             help="Add additional prefix to log lines to indicate which replica is printing the log",
         )
+        subparser.add_argument(
+            "--stdin",
+            action="store_true",
+            default=False,
+            help="Read JSON input from stdin to parse into torchx run args and run the component.",
+        )
         subparser.add_argument(
             "component_name_and_args",
             nargs=argparse.REMAINDER,
         )

-    def
+    def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
         if args.scheduler == "local":
-            logger.warning(
-                "`local` scheduler is deprecated and will be"
-                " removed in the near future,"
-                " please use other variants of the local scheduler"
-                " (e.g. `local_cwd`)"
-            )
-
-        scheduler_opts = runner.scheduler_run_opts(args.scheduler)
-        cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
-        config.apply(scheduler=args.scheduler, cfg=cfg)
+            logger.warning(LOCAL_SCHEDULER_WARNING_MSG)

-
-
-
+        config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg)
+        component_args = (
+            args.component_args_str
+            if args.component_args_str != []
+            else args.component_args
         )
         try:
+            workspace = Workspace.from_str(args.workspace) if args.workspace else None
+
             if args.dryrun:
                 dryrun_info = runner.dryrun_component(
-
+                    args.component_name,
                     component_args,
                     args.scheduler,
-                    workspace=
-                    cfg=
+                    workspace=workspace,
+                    cfg=args.scheduler_cfg,
                     parent_run_id=args.parent_run_id,
                 )
                 print(
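The refactor separates cfg parsing from execution: `_run_inner` now receives an already-parsed `Dict[str, CfgVal]`, so the CLI path parses `-cfg` strings with `cfg_from_str` while the stdin path (below) round-trips the JSON `scheduler_args` mapping through the new `cfg_from_json_repr`. A rough sketch of the two parse paths, assuming `local_cwd` with its `log_dir` runopt as the example scheduler:

```python
from torchx.runner import get_runner

opts = get_runner().scheduler_run_opts("local_cwd")

# CLI path: `-cfg key1=value1,key2=value2`
cfg_from_cli = opts.cfg_from_str("log_dir=/tmp/logs")

# stdin path: the JSON scheduler_args mapping, serialized back to a JSON string
cfg_from_json = opts.cfg_from_json_repr('{"log_dir": "/tmp/logs"}')

# Both are expected to yield the same Dict[str, CfgVal] mapping.
print(cfg_from_cli, cfg_from_json)
```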
@@ -233,11 +308,11 @@ class CmdRun(SubCommand):
                 print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}")
             else:
                 app_handle = runner.run_component(
-
+                    args.component_name,
                     component_args,
                     args.scheduler,
                     workspace=args.workspace,
-                    cfg=
+                    cfg=args.scheduler_cfg,
                     parent_run_id=args.parent_run_id,
                 )
                 # DO NOT delete this line. It is used by slurm tests to retrieve the app id
@@ -258,19 +333,118 @@
             )

         except (ComponentValidationException, ComponentNotFoundException) as e:
-            error_msg =
+            error_msg = (
+                f"\nFailed to run component `{args.component_name}` got errors: \n {e}"
+            )
             logger.error(error_msg)
             sys.exit(1)
         except specs.InvalidRunConfigException as e:
             error_msg = (
-
-
-
-                "
+                "Invalid scheduler configuration: %s\n"
+                "To configure scheduler options, either:\n"
+                " 1. Use the `-cfg` command-line argument, e.g., `-cfg key1=value1,key2=value2`\n"
+                " 2. Set up a `.torchxconfig` file. For more details, visit: https://meta-pytorch.org/torchx/main/runner.config.html\n"
+                "Run `torchx runopts %s` to check all available configuration options for the "
+                "`%s` scheduler."
+            )
+            print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
+            sys.exit(1)
+
+    def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
+        scheduler_opts = runner.scheduler_run_opts(args.scheduler)
+        cfg = scheduler_opts.cfg_from_str(args.scheduler_args)
+
+        component, component_args = _parse_component_name_and_args(
+            args.component_name_and_args,
+            none_throws(self._subparser),
+        )
+        torchx_run_args = torchx_run_args_from_argparse(
+            args, component, component_args, cfg
+        )
+        self._run_inner(runner, torchx_run_args)
+
+    def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
+        torchx_run_args = torchx_run_args_from_json(stdin_data)
+        scheduler_opts = runner.scheduler_run_opts(torchx_run_args.scheduler)
+        cfg = scheduler_opts.cfg_from_json_repr(
+            json.dumps(torchx_run_args.scheduler_args)
+        )
+        torchx_run_args.scheduler_cfg = cfg
+        self._run_inner(runner, torchx_run_args)
+
+    def _get_torchx_stdin_args(
+        self, args: argparse.Namespace
+    ) -> Optional[Dict[str, Any]]:
+        if not args.stdin:
+            return None
+        if self._stdin_data_json is None:
+            self._stdin_data_json = self.torchx_json_from_stdin(args)
+        return self._stdin_data_json
+
+    def torchx_json_from_stdin(
+        self, args: Optional[argparse.Namespace] = None
+    ) -> Dict[str, Any]:
+        try:
+            stdin_data_json = json.load(sys.stdin)
+            if args and args.dryrun:
+                stdin_data_json["dryrun"] = True
+            if not isinstance(stdin_data_json, dict):
+                logger.error(
+                    "Invalid JSON input for `torchx run` command. Expected a dictionary."
+                )
+                sys.exit(1)
+            return stdin_data_json
+        except (json.JSONDecodeError, EOFError):
+            logger.error(
+                "Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
             )
-            logger.error(error_msg)
             sys.exit(1)

+    def verify_no_extra_args(self, args: argparse.Namespace) -> None:
+        """
+        Verifies that only --stdin was provided when using stdin mode.
+        """
+        if not args.stdin:
+            return
+
+        subparser = none_throws(self._subparser)
+        conflicting_args = []
+
+        # Check each argument against its default value
+        for action in subparser._actions:
+            if action.dest == "stdin":  # Skip stdin itself
+                continue
+            if action.dest == "help":  # Skip help
+                continue
+            if action.dest == "dryrun":  # Skip dryrun
+                continue
+
+            current_value = getattr(args, action.dest, None)
+            default_value = action.default
+
+            # For arguments that differ from default
+            if current_value != default_value:
+                # Handle special cases where non-default doesn't mean explicitly set
+                if action.dest == "component_name_and_args" and current_value == []:
+                    continue  # Empty list is still default
+                print(f"*********\n {default_value} = {current_value}")
+                conflicting_args.append(f"--{action.dest.replace('_', '-')}")
+
+        if conflicting_args:
+            subparser.error(
+                f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
+                "All configuration should be provided in JSON input."
+            )
+
+    def _run(self, runner: Runner, args: argparse.Namespace) -> None:
+        self.verify_no_extra_args(args)
+        if args.stdin:
+            stdin_data_json = self._get_torchx_stdin_args(args)
+            if stdin_data_json is not None:
+                self._run_from_stdin_args(runner, stdin_data_json)
+        else:
+            self._run_from_cli_args(runner, args)
+
     def run(self, args: argparse.Namespace) -> None:
         os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
         component_defaults = load_sections(prefix="component")
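Taken together, the new flow can be driven end to end by piping a JSON document into `torchx run --stdin`; a sketch using `subprocess` (assumes `torchx` is on `PATH` and reuses the hypothetical payload shape from above). Note that `verify_no_extra_args` rejects combining `--stdin` with any other flag except `--dryrun`:

```python
import json
import subprocess

payload = {
    "component_name": "utils.echo",  # hypothetical component
    "scheduler": "local_cwd",
    "scheduler_args": {},
}

# Equivalent to: echo '{...}' | torchx run --stdin
subprocess.run(
    ["torchx", "run", "--stdin"],
    input=json.dumps(payload),
    text=True,
    check=True,
)
```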
torchx/cli/cmd_tracker.py
CHANGED
@@ -45,7 +45,7 @@ class CmdTracker(SubCommand):
         else:
             raise RuntimeError(
                 "No trackers configured."
-                " See: https://pytorch.org/torchx/latest/runtime/tracking.html"
+                " See: https://meta-pytorch.org/torchx/latest/runtime/tracking.html"
             )

     def add_list_job_arguments(self, subparser: argparse.ArgumentParser) -> None:
torchx/components/__init__.py
CHANGED
@@ -181,7 +181,7 @@ To validate that you've defined your component correctly you can either:

 1. (easiest) Dryrun your component's ``--help`` with the cli: ``torchx run --dryrun ~/component.py:train --help``
 2. Use the component :ref:`linter<specs:Component Linter>`
-   (see `dist_test.py <https://github.com/pytorch/torchx/blob/main/torchx/components/test/dist_test.py>`_ as an example)
+   (see `dist_test.py <https://github.com/meta-pytorch/torchx/blob/main/torchx/components/test/dist_test.py>`_ as an example)


 Running as a Job
@@ -298,13 +298,6 @@ imagine the component is defined as:
 * ``*args=["--help"]``: ``torchx run comp.py:f -- --help``
 * ``*args=["--i", "2"]``: ``torchx run comp.py:f --i 1 -- --i 2``

-Run in a Pipeline
---------------------------------
-
-The :ref:`torchx.pipelines<pipelines:torchx.pipelines>` define adapters that
-convert a torchx component into the object that represents a pipeline "stage" in the
-target pipeline platform (see :ref:`Pipelines` for a list of supported pipeline orchestrators).
-
 Additional Resources
 -----------------------

torchx/components/dist.py
CHANGED
@@ -92,6 +92,7 @@ def spmd(
     h: str = "gpu.small",
     j: str = "1x1",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
@@ -131,6 +132,7 @@ def spmd(
         h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
         j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
         env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         mounts: (for docker based runs only) mounts to mount into the worker environment/container
             (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
@@ -150,6 +152,7 @@ def spmd(
         h=h,
         j=str(StructuredJArgument.parse_from(h, j)),
         env=env,
+        metadata=metadata,
         max_retries=max_retries,
         mounts=mounts,
         debug=debug,
@@ -168,6 +171,7 @@ def ddp(
     memMB: int = 1024,
     j: str = "1x2",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     rdzv_port: int = 29500,
    rdzv_backend: str = "c10d",
@@ -186,7 +190,7 @@ def ddp(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         script_args: arguments to the main module
@@ -201,6 +205,7 @@ def ddp(
         h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
         j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
         env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
             Only takes effect when running multi-node. When running single node, this parameter
@@ -237,8 +242,8 @@ def ddp(
     # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
     rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")

-
-
+    env = env or {}
+    metadata = metadata or {}

     argname = StructuredNameArgument.parse_from(
         name=name,
@@ -299,6 +304,7 @@ def ddp(
                 mounts=specs.parse_mounts(mounts) if mounts else [],
             )
         ],
+        metadata=metadata,
     )

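A short sketch of the new `metadata` argument on the `ddp` component — the mapping is forwarded to `specs.AppDef(metadata=...)` and from there to the scheduler; the script name and metadata values below are hypothetical:

```python
from torchx.components.dist import ddp

app = ddp(
    script="train.py",              # hypothetical training script
    j="2x4",                        # 2 nodes x 4 procs per node
    env={"ENV1": "v1"},
    metadata={"team": "ml-infra"},  # hypothetical scheduler metadata
)
print(app.metadata)
```

From the CLI the same would be expressed as `--metadata KEY1=v1,KEY2=v2`, mirroring the `env` flag format shown in the docstring.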
torchx/components/integration_tests/component_provider.py
CHANGED
@@ -109,7 +109,7 @@ class CopyComponentProvider(ComponentProvider):
         self._dst_path = "<None>"

     def setUp(self) -> None:
-        if self._scheduler in ["local_cwd"
+        if self._scheduler in ["local_cwd"]:
             fname = "torchx_copy_test.txt"
             self._src_path: str = os.path.join(tempfile.gettempdir(), fname)
             self._dst_path: str = os.path.join(tempfile.gettempdir(), f"{fname}.copy")
@@ -126,7 +126,7 @@ class CopyComponentProvider(ComponentProvider):
     def tearDown(self) -> None:
         if os.path.exists(self._dst_path):
             os.remove(self._dst_path)
-        if self._scheduler in ["local_cwd"
+        if self._scheduler in ["local_cwd"] and os.path.exists(self._dst_path):
             os.remove(self._dst_path)

     def get_app_def(self) -> AppDef:
torchx/components/utils.py
CHANGED
@@ -154,7 +154,7 @@ def python(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
torchx/distributed/__init__.py
CHANGED
@@ -48,7 +48,7 @@ def local_rank() -> int:
         " but the `LOCAL_RANK` environment variable is not set. Will trivially return 0 for local_rank.\n"
         " It is recommended to use torchrun/torchx to run your script or set the `LOCAL_RANK` manually.\n"
         " For additional details see:\n"
-        " 1) https://pytorch.org/torchx/latest/components/distributed.html\n"
+        " 1) https://meta-pytorch.org/torchx/latest/components/distributed.html\n"
         " 2) https://pytorch.org/docs/stable/elastic/run.html\n"
         "=============================================================================================="
     )