torchx-nightly 2025.9.28__py3-none-any.whl → 2025.11.17__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
- torchx/_version.py +8 -0
- torchx/cli/cmd_run.py +10 -5
- torchx/cli/cmd_tracker.py +1 -1
- torchx/components/__init__.py +1 -1
- torchx/components/dist.py +9 -3
- torchx/components/utils.py +1 -1
- torchx/distributed/__init__.py +1 -1
- torchx/runner/api.py +30 -22
- torchx/runner/config.py +2 -0
- torchx/schedulers/__init__.py +8 -9
- torchx/schedulers/api.py +9 -4
- torchx/schedulers/aws_batch_scheduler.py +44 -1
- torchx/schedulers/docker_scheduler.py +3 -0
- torchx/schedulers/kubernetes_scheduler.py +200 -17
- torchx/schedulers/slurm_scheduler.py +11 -2
- torchx/specs/__init__.py +30 -7
- torchx/specs/api.py +215 -10
- torchx/specs/file_linter.py +1 -1
- torchx/specs/finder.py +1 -1
- torchx/specs/named_resources_aws.py +13 -2
- torchx/tracker/__init__.py +2 -2
- torchx/tracker/api.py +1 -1
- torchx/util/entrypoints.py +1 -6
- torchx/version.py +2 -2
- torchx/workspace/__init__.py +1 -1
- torchx/workspace/api.py +65 -110
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info}/METADATA +34 -21
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info}/RECORD +32 -31
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info}/WHEEL +1 -1
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info}/entry_points.txt +0 -0
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2025.9.28.dist-info → torchx_nightly-2025.11.17.dist-info}/top_level.txt +0 -0
torchx/_version.py
ADDED
torchx/cli/cmd_run.py
CHANGED
@@ -26,7 +26,7 @@ from torchx.cli.cmd_log import get_logs
 from torchx.runner import config, get_runner, Runner
 from torchx.runner.config import load_sections
 from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
-from torchx.specs import CfgVal
+from torchx.specs import CfgVal, Workspace
 from torchx.specs.finder import (
     _Component,
     ComponentNotFoundException,
@@ -36,7 +36,6 @@ from torchx.specs.finder import (
 )
 from torchx.util.log_tee_helpers import tee_logs
 from torchx.util.types import none_throws
-from torchx.workspace import Workspace


 MISSING_COMPONENT_ERROR_MSG = (
@@ -344,7 +343,7 @@ class CmdRun(SubCommand):
             "Invalid scheduler configuration: %s\n"
             "To configure scheduler options, either:\n"
             " 1. Use the `-cfg` command-line argument, e.g., `-cfg key1=value1,key2=value2`\n"
-            " 2. Set up a `.torchxconfig` file. For more details, visit: https://pytorch.org/torchx/main/runner.config.html\n"
+            " 2. Set up a `.torchxconfig` file. For more details, visit: https://meta-pytorch.org/torchx/main/runner.config.html\n"
             "Run `torchx runopts %s` to check all available configuration options for the "
             "`%s` scheduler."
         )
@@ -379,12 +378,16 @@ class CmdRun(SubCommand):
         if not args.stdin:
             return None
         if self._stdin_data_json is None:
-            self._stdin_data_json = self.torchx_json_from_stdin()
+            self._stdin_data_json = self.torchx_json_from_stdin(args)
         return self._stdin_data_json

-    def torchx_json_from_stdin(self) -> Dict[str, Any]:
+    def torchx_json_from_stdin(
+        self, args: Optional[argparse.Namespace] = None
+    ) -> Dict[str, Any]:
         try:
             stdin_data_json = json.load(sys.stdin)
+            if args and args.dryrun:
+                stdin_data_json["dryrun"] = True
             if not isinstance(stdin_data_json, dict):
                 logger.error(
                     "Invalid JSON input for `torchx run` command. Expected a dictionary."
@@ -413,6 +416,8 @@ class CmdRun(SubCommand):
                 continue
             if action.dest == "help":  # Skip help
                 continue
+            if action.dest == "dryrun":  # Skip dryrun
+                continue

             current_value = getattr(args, action.dest, None)
             default_value = action.default
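A minimal sketch of the new stdin handling, using only names that appear in the hunks above (this is a standalone replica, not the torchx source): when `--dryrun` is set, the flag is mirrored into the JSON payload read from stdin, and the `dryrun` dest is skipped when the remaining CLI args are re-serialized.

    import argparse
    import json
    import sys
    from typing import Any, Dict, Optional

    def torchx_json_from_stdin(args: Optional[argparse.Namespace] = None) -> Dict[str, Any]:
        # e.g. echo '{"component": "utils.echo"}' | torchx run --stdin --dryrun
        data = json.load(sys.stdin)
        if not isinstance(data, dict):
            raise ValueError("expected a JSON dictionary on stdin")
        if args is not None and getattr(args, "dryrun", False):
            data["dryrun"] = True  # mirror the --dryrun CLI flag into the payload
        return data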
torchx/cli/cmd_tracker.py
CHANGED
@@ -45,7 +45,7 @@ class CmdTracker(SubCommand):
         else:
             raise RuntimeError(
                 "No trackers configured."
-                " See: https://pytorch.org/torchx/latest/runtime/tracking.html"
+                " See: https://meta-pytorch.org/torchx/latest/runtime/tracking.html"
             )

     def add_list_job_arguments(self, subparser: argparse.ArgumentParser) -> None:
torchx/components/__init__.py
CHANGED
@@ -181,7 +181,7 @@ To validate that you've defined your component correctly you can either:

 1. (easiest) Dryrun your component's ``--help`` with the cli: ``torchx run --dryrun ~/component.py:train --help``
 2. Use the component :ref:`linter<specs:Component Linter>`
-   (see `dist_test.py <https://github.com/pytorch/torchx/blob/main/torchx/components/test/dist_test.py>`_ as an example)
+   (see `dist_test.py <https://github.com/meta-pytorch/torchx/blob/main/torchx/components/test/dist_test.py>`_ as an example)


 Running as a Job
torchx/components/dist.py
CHANGED
@@ -92,6 +92,7 @@ def spmd(
     h: str = "gpu.small",
     j: str = "1x1",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
@@ -131,6 +132,7 @@ def spmd(
         h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
         j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
         env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         mounts: (for docker based runs only) mounts to mount into the worker environment/container
             (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
@@ -150,6 +152,7 @@ def spmd(
         h=h,
         j=str(StructuredJArgument.parse_from(h, j)),
         env=env,
+        metadata=metadata,
         max_retries=max_retries,
         mounts=mounts,
         debug=debug,
@@ -168,6 +171,7 @@ def ddp(
     memMB: int = 1024,
     j: str = "1x2",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     rdzv_port: int = 29500,
     rdzv_backend: str = "c10d",
@@ -186,7 +190,7 @@ def ddp(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         script_args: arguments to the main module
@@ -201,6 +205,7 @@ def ddp(
         h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
         j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
         env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
             Only takes effect when running multi-node. When running single node, this parameter
@@ -237,8 +242,8 @@ def ddp(
     # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
     rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")

-
-
+    env = env or {}
+    metadata = metadata or {}

     argname = StructuredNameArgument.parse_from(
         name=name,
@@ -299,6 +304,7 @@ def ddp(
                 mounts=specs.parse_mounts(mounts) if mounts else [],
             )
         ],
+        metadata=metadata,
     )
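A hedged usage sketch for the new `metadata` parameter, using only names shown in the hunks above (the full `ddp` signature has more parameters than the diff displays):

    from torchx.components.dist import ddp

    # metadata is stored on the resulting AppDef and forwarded to the scheduler
    appdef = ddp(
        script="train.py",
        j="2x4",  # 2 nodes x 4 procs per node
        env={"LOGLEVEL": "INFO"},
        metadata={"team": "ml-infra", "experiment": "baseline"},
    )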
torchx/components/utils.py
CHANGED
@@ -154,7 +154,7 @@ def python(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
torchx/distributed/__init__.py
CHANGED
@@ -48,7 +48,7 @@ def local_rank() -> int:
         " but the `LOCAL_RANK` environment variable is not set. Will trivially return 0 for local_rank.\n"
         " It is recommended to use torchrun/torchx to run your script or set the `LOCAL_RANK` manually.\n"
         " For additional details see:\n"
-        " 1) https://pytorch.org/torchx/latest/components/distributed.html\n"
+        " 1) https://meta-pytorch.org/torchx/latest/components/distributed.html\n"
         " 2) https://pytorch.org/docs/stable/elastic/run.html\n"
         "=============================================================================================="
     )
torchx/runner/api.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -43,6 +42,7 @@ from torchx.specs import (
     parse_app_handle,
     runopts,
     UnknownAppException,
+    Workspace,
 )
 from torchx.specs.finder import get_component
 from torchx.tracker.api import (
@@ -54,7 +54,7 @@ from torchx.tracker.api import (
 from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID

 from torchx.util.types import none_throws
-from torchx.workspace […]
+from torchx.workspace import WorkspaceMixin

 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -420,36 +420,44 @@ class Runner:
             scheduler,
             runcfg=json.dumps(cfg) if cfg else None,
             workspace=str(workspace),
-        ):
+        ) as ctx:
             sched = self._scheduler(scheduler)
             resolved_cfg = sched.run_opts().resolve(cfg)

             sched._pre_build_validate(app, scheduler, resolved_cfg)

-            if […]
-            [… 16 removed lines; content lost in extraction …]
-                        f"Reusing original image `{old_img}` for role[0]={role.name}."
-                        " Either a patch was built or no changes to workspace was detected."
+            if isinstance(sched, WorkspaceMixin):
+                if workspace:
+                    # NOTE: torchx originally took workspace as a runner arg and only applied the workspace to role[0]
+                    # later, torchx added support for the workspace attr in Role
+                    # for BC, give precedence to the workspace argument over the workspace attr for role[0]
+                    if app.roles[0].workspace:
+                        logger.info(
+                            "Overriding role[%d] (%s) workspace to `%s`"
+                            "To use the role's workspace attr pass: --workspace='' from CLI or workspace=None programmatically.",
+                            0,
+                            role.name,
+                            str(app.roles[0].workspace),
+                        )
+                    app.roles[0].workspace = (
+                        Workspace.from_str(workspace)
+                        if isinstance(workspace, str)
+                        else workspace
                     )

+                sched.build_workspaces(app.roles, resolved_cfg)
+
             sched._validate(app, scheduler, resolved_cfg)
             dryrun_info = sched.submit_dryrun(app, resolved_cfg)
             dryrun_info._scheduler = scheduler
+
+            event = ctx._torchx_event
+            event.scheduler = scheduler
+            event.runcfg = json.dumps(cfg) if cfg else None
+            event.app_id = app.name
+            event.app_image = none_throws(dryrun_info._app).roles[0].image
+            event.app_metadata = app.metadata
+
             return dryrun_info

     def scheduler_run_opts(self, scheduler: str) -> runopts:
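The net effect of the dryrun changes, as a sketch (assuming `Runner.dryrun` keeps its public signature; `Workspace.from_str` appears in the hunks above): an explicit `workspace` argument takes precedence over `app.roles[0].workspace`, and a string workspace is normalized via `Workspace.from_str`.

    from torchx.specs import Workspace

    ws = Workspace.from_str("/path/to/project")  # str form is normalized to a Workspace
    # runner.dryrun(app, "kubernetes", cfg=cfg, workspace=ws)    # overrides roles[0].workspace
    # runner.dryrun(app, "kubernetes", cfg=cfg, workspace=None)  # keeps roles[0].workspace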
torchx/runner/config.py
CHANGED
@@ -494,6 +494,8 @@ def find_configs(dirs: Optional[Iterable[str]] = None) -> List[str]:

     config = os.getenv(ENV_TORCHXCONFIG)
     if config is not None:
+        if not config:
+            return []
         configfile = Path(config)
         if not configfile.is_file():
             raise FileNotFoundError(
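Behavior sketch, assuming `ENV_TORCHXCONFIG` names the `TORCHXCONFIG` environment variable (an assumption; the constant's value is not shown in the diff): setting it to an empty string now short-circuits config discovery instead of failing on the empty path.

    import os

    os.environ["TORCHXCONFIG"] = ""  # assumption: ENV_TORCHXCONFIG == "TORCHXCONFIG"
    # find_configs() now returns [] -- previously Path("") failed the is_file()
    # check and raised FileNotFoundError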
torchx/schedulers/__init__.py
CHANGED
@@ -49,15 +49,14 @@ def get_scheduler_factories(
     The first scheduler in the dictionary is used as the default scheduler.
     """

-    [… 8 removed lines; content lost in extraction …]
-    )
+    if skip_defaults:
+        default_schedulers = {}
+    else:
+        default_schedulers: dict[str, SchedulerFactory] = {}
+        for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
+            default_schedulers[scheduler] = _defer_load_scheduler(path)
+
+    return load_group(group, default=default_schedulers)


 def get_default_scheduler_name() -> str:
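A short usage sketch of the rewritten factory loading (assuming `get_scheduler_factories` keeps its `group` and `skip_defaults` parameters, both of which are used in the hunk above): with `skip_defaults=True` only entry-point registered schedulers are returned; otherwise each built-in scheduler module is wrapped in a deferred loader so nothing is imported until first use.

    from torchx.schedulers import get_scheduler_factories

    all_factories = get_scheduler_factories()                  # built-ins + entry points
    custom_only = get_scheduler_factories(skip_defaults=True)  # entry points only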
torchx/schedulers/api.py
CHANGED
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -22,8 +21,9 @@ from torchx.specs import (
     Role,
     RoleStatus,
     runopts,
+    Workspace,
 )
-from torchx.workspace […]
+from torchx.workspace import WorkspaceMixin


 DAYS_IN_2_WEEKS = 14
@@ -131,7 +131,7 @@ class Scheduler(abc.ABC, Generic[T, A, D]):
         self,
         app: A,
         cfg: T,
-        workspace: […]
+        workspace: str | Workspace | None = None,
     ) -> str:
         """
         Submits the application to be run by the scheduler.
@@ -145,7 +145,12 @@ class Scheduler(abc.ABC, Generic[T, A, D]):
         resolved_cfg = self.run_opts().resolve(cfg)
         if workspace:
             assert isinstance(self, WorkspaceMixin)
-            [… removed line; content lost in extraction …]
+
+            if isinstance(workspace, str):
+                workspace = Workspace.from_str(workspace)
+
+            app.roles[0].workspace = workspace
+            self.build_workspaces(app.roles, resolved_cfg)

         # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg
         dryrun_info = self.submit_dryrun(app, resolved_cfg)
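A sketch of the normalization `submit()` now performs (names taken from the hunk above; the helper name here is hypothetical):

    from typing import Optional, Union

    from torchx.specs import Workspace

    def normalize(workspace: Union[str, Workspace, None]) -> Optional[Workspace]:
        # mirrors the new submit() logic: strings are parsed, Workspace passes through
        if isinstance(workspace, str):
            return Workspace.from_str(workspace)
        return workspace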
torchx/schedulers/aws_batch_scheduler.py
CHANGED
@@ -92,6 +92,8 @@ ENV_TORCHX_ROLE_IDX = "TORCHX_ROLE_IDX"

 ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME"

+ENV_TORCHX_IMAGE = "TORCHX_IMAGE"
+
 DEFAULT_ROLE_NAME = "node"

 TAG_TORCHX_VER = "torchx.pytorch.org/version"
@@ -99,6 +101,37 @@ TAG_TORCHX_APPNAME = "torchx.pytorch.org/app-name"
 TAG_TORCHX_USER = "torchx.pytorch.org/user"


+def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
+    """
+    Parse ulimit string in format: name:softLimit:hardLimit
+    Multiple ulimits separated by commas.
+    """
+    if not ulimits_list:
+        return []
+
+    ulimits = []
+    for ulimit_str in ulimits_list:
+        if not ulimit_str.strip():
+            continue
+
+        parts = ulimit_str.strip().split(":")
+        if len(parts) != 3:
+            raise ValueError(
+                f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
+            )
+
+        name, soft_limit, hard_limit = parts
+        ulimits.append(
+            {
+                "name": name,
+                "softLimit": int(soft_limit) if soft_limit != "-1" else -1,
+                "hardLimit": int(hard_limit) if hard_limit != "-1" else -1,
+            }
+        )
+
+    return ulimits
+
+
 if TYPE_CHECKING:
     from docker import DockerClient

@@ -177,7 +210,8 @@ def _role_to_node_properties(
     privileged: bool = False,
     job_role_arn: Optional[str] = None,
     execution_role_arn: Optional[str] = None,
-) -> Dict[str, Any]:
+    ulimits: Optional[List[Dict[str, Any]]] = None,
+) -> Dict[str, Any]:
     role.mounts += get_device_mounts(role.resource.devices)

     mount_points = []
@@ -239,6 +273,7 @@ def _role_to_node_properties(
         "environment": [{"name": k, "value": v} for k, v in role.env.items()],
         "privileged": privileged,
         "resourceRequirements": resource_requirements_from_resource(role.resource),
+        **({"ulimits": ulimits} if ulimits else {}),
         "linuxParameters": {
             # To support PyTorch dataloaders we need to set /dev/shm to larger
             # than the 64M default.
@@ -361,6 +396,7 @@ class AWSBatchOpts(TypedDict, total=False):
     priority: int
     job_role_arn: Optional[str]
     execution_role_arn: Optional[str]
+    ulimits: Optional[list[str]]


 class AWSBatchScheduler(
@@ -506,6 +542,7 @@ class AWSBatchScheduler(
             role = values.apply(role)
             role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx)
             role.env[ENV_TORCHX_ROLE_NAME] = str(role.name)
+            role.env[ENV_TORCHX_IMAGE] = role.image

             nodes.append(
                 _role_to_node_properties(
@@ -514,6 +551,7 @@ class AWSBatchScheduler(
                     privileged=cfg["privileged"],
                     job_role_arn=cfg.get("job_role_arn"),
                     execution_role_arn=cfg.get("execution_role_arn"),
+                    ulimits=parse_ulimits(cfg.get("ulimits") or []),
                 )
             )
             node_idx += role.num_replicas
@@ -599,6 +637,11 @@ class AWSBatchScheduler(
             type_=str,
             help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
         )
+        opts.add(
+            "ulimits",
+            type_=List[str],
+            help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
+        )
         return opts

     def _get_job_id(self, app_id: str) -> Optional[str]:
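Usage sketch for the new `ulimits` run option, with the format taken from the hunks above (the `-cfg` spelling on the CLI is an assumption based on torchx's standard scheduler-config flag):

    from torchx.schedulers.aws_batch_scheduler import parse_ulimits

    parse_ulimits(["nofile:65536:65536", "core:-1:-1"])
    # -> [{"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
    #     {"name": "core",   "softLimit": -1,    "hardLimit": -1}]

    # on the CLI (assumption):
    #   torchx run -s aws_batch -cfg queue=my-queue,ulimits=nofile:65536:65536 ...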
torchx/schedulers/docker_scheduler.py
CHANGED
@@ -84,6 +84,8 @@ LABEL_APP_ID: str = "torchx.pytorch.org/app-id"
 LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
 LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"

+ENV_TORCHX_IMAGE: str = "TORCHX_IMAGE"
+
 NETWORK = "torchx"


@@ -279,6 +281,7 @@ class DockerScheduler(

         # configure distributed host envs
         env["TORCHX_RANK0_HOST"] = rank0_name
+        env[ENV_TORCHX_IMAGE] = replica_role.image

         c = DockerContainer(
             image=replica_role.image,
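Both schedulers above now export the role's container image to the job through `TORCHX_IMAGE`; application code can introspect it at runtime, as in this sketch:

    import os

    image = os.environ.get("TORCHX_IMAGE")  # set by the AWS Batch and Docker schedulers above
    if image:
        print(f"running inside image: {image}")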