torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly might be problematic.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +7 -6
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +79 -5
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +431 -32
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +254 -22
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
  108. torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
  109. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0

torchx/runner/events/api.py CHANGED
@@ -5,10 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import json
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 
 class SourceType(str, Enum):
@@ -23,15 +25,18 @@ class TorchxEvent:
     The class represents the event produced by ``torchx.runner`` api calls.
 
     Arguments:
-        session: Session id that was used to execute request.
+        session: Session id of the current run
         scheduler: Scheduler that is used to execute request
         api: Api name
         app_id: Unique id that is set by the underlying scheduler
-        image: Image/container bundle that is used to execute request.
+        app_image: Image/container bundle that is used to execute request.
+        app_metadata: metadata to the app (treatment of metadata is scheduler dependent)
        runcfg: Run config that was used to schedule app.
        source: Type of source the event is generated.
        cpu_time_usec: CPU time spent in usec
        wall_time_usec: Wall time spent in usec
+        start_epoch_time_usec: Epoch time in usec when runner event starts
+        Workspace: Track how different workspaces/no workspace affects build and scheduler
    """
 
    session: str
@@ -39,11 +44,17 @@ class TorchxEvent:
    api: str
    app_id: Optional[str] = None
    app_image: Optional[str] = None
+    app_metadata: Optional[Dict[str, str]] = None
    runcfg: Optional[str] = None
    raw_exception: Optional[str] = None
    source: SourceType = SourceType.UNKNOWN
    cpu_time_usec: Optional[int] = None
    wall_time_usec: Optional[int] = None
+    start_epoch_time_usec: Optional[int] = None
+    workspace: Optional[str] = None
+    exception_type: Optional[str] = None
+    exception_message: Optional[str] = None
+    exception_source_location: Optional[str] = None
 
    def __str__(self) -> str:
        return self.serialize()
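
The expanded event record above can be exercised as follows; this is a minimal sketch with illustrative values, and serialize() (reached via __str__) presumably emits JSON given the json/asdict imports in this module:

    from torchx.runner.events.api import SourceType, TorchxEvent

    event = TorchxEvent(
        session="run-42",                   # illustrative session id
        scheduler="aws_batch",
        api="schedule",
        app_id="trainer-abc123",            # normally set by the underlying scheduler
        app_image="example.com/train:1.0",  # illustrative image
        app_metadata={"team": "ml-infra"},  # treatment is scheduler dependent
        source=SourceType.UNKNOWN,
        workspace="file:///tmp/my_workspace",
    )
    print(event)  # __str__ delegates to serialize()
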
torchx/runner/events/handlers.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import logging
 from typing import Dict
 
torchx/runtime/tracking/__init__.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 .. note:: EXPERIMENTAL, USE AT YOUR OWN RISK, APIs SUBJECT TO CHANGE
 
torchx/runtime/tracking/api.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import json
 from typing import Dict, Union
torchx/schedulers/__init__.py CHANGED
@@ -5,12 +5,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import importlib
-from typing import Dict, Mapping
+from typing import Mapping, Protocol
 
 from torchx.schedulers.api import Scheduler
 from torchx.util.entrypoints import load_group
-from typing_extensions import Protocol
 
 DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
     "local_docker": "torchx.schedulers.docker_scheduler",
@@ -19,16 +20,14 @@ DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
     "kubernetes": "torchx.schedulers.kubernetes_scheduler",
     "kubernetes_mcad": "torchx.schedulers.kubernetes_mcad_scheduler",
     "aws_batch": "torchx.schedulers.aws_batch_scheduler",
-    "gcp_batch": "torchx.schedulers.gcp_batch_scheduler",
-    "ray": "torchx.schedulers.ray_scheduler",
+    "aws_sagemaker": "torchx.schedulers.aws_sagemaker_scheduler",
     "lsf": "torchx.schedulers.lsf_scheduler",
 }
 
 
 class SchedulerFactory(Protocol):
     # pyre-fixme: Scheduler opts
-    def __call__(self, session_name: str, **kwargs: object) -> Scheduler:
-        ...
+    def __call__(self, session_name: str, **kwargs: object) -> Scheduler: ...
 
 
 def _defer_load_scheduler(path: str) -> SchedulerFactory:
@@ -40,22 +39,24 @@ def _defer_load_scheduler(path: str) -> SchedulerFactory:
     return run
 
 
-def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
+def get_scheduler_factories(
+    group: str = "torchx.schedulers", skip_defaults: bool = False
+) -> dict[str, SchedulerFactory]:
     """
-    get_scheduler_factories returns all the available schedulers names and the
+    get_scheduler_factories returns all the available schedulers names under `group` and the
     method to instantiate them.
 
     The first scheduler in the dictionary is used as the default scheduler.
     """
 
-    default_schedulers: Dict[str, SchedulerFactory] = {}
-    for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
-        default_schedulers[scheduler] = _defer_load_scheduler(path)
+    if skip_defaults:
+        default_schedulers = {}
+    else:
+        default_schedulers: dict[str, SchedulerFactory] = {}
+        for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
+            default_schedulers[scheduler] = _defer_load_scheduler(path)
 
-    return load_group(
-        "torchx.schedulers",
-        default=default_schedulers,
-    )
+    return load_group(group, default=default_schedulers)
 
 
 def get_default_scheduler_name() -> str:
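
Given the new signature above, a minimal usage sketch; the custom group name is hypothetical, and resolution still goes through load_group() entry points:

    from torchx.schedulers import get_scheduler_factories

    # Built-in defaults: the first entry in the dict is the default scheduler.
    factories = get_scheduler_factories()
    scheduler = factories["aws_batch"](session_name="demo")

    # Load only schedulers registered under a custom entry-point group,
    # skipping the built-in defaults entirely.
    custom = get_scheduler_factories(group="my_project.schedulers", skip_defaults=True)
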
torchx/schedulers/api.py CHANGED
@@ -1,10 +1,11 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import re
 from dataclasses import dataclass, field
@@ -21,8 +22,9 @@ from torchx.specs import (
     Role,
     RoleStatus,
     runopts,
+    Workspace,
 )
-from torchx.workspace.api import WorkspaceMixin
+from torchx.workspace import WorkspaceMixin
 
 
 DAYS_IN_2_WEEKS = 14
@@ -41,7 +43,7 @@ class DescribeAppResponse:
     the status and description of the application as known by the scheduler.
     For some schedulers implementations this response object has necessary
     and sufficient information to recreate an ``AppDef`` object. For these types
-    of schedulers, the user can re-``run()`` the recreted application. Otherwise
+    of schedulers, the user can re-``run()`` the recreated application. Otherwise
     the user can only call non-creating methods (e.g. ``wait()``, ``status()``,
     etc).
 
@@ -59,6 +61,7 @@ class DescribeAppResponse:
     msg: str = NONE
     structured_error_msg: str = NONE
     ui_url: Optional[str] = None
+    metadata: dict[str, str] = field(default_factory=dict)
 
     roles_statuses: List[RoleStatus] = field(default_factory=list)
     roles: List[Role] = field(default_factory=list)
@@ -83,6 +86,7 @@ class ListAppResponse:
     app_id: str
     state: AppState
     app_handle: str = "<NOT_SET>"
+    name: str = ""
 
     # Implementing __hash__() makes ListAppResponse hashable which makes
     # it easier to check if a ListAppResponse object exists in a list of
@@ -126,7 +130,7 @@ class Scheduler(abc.ABC, Generic[T]):
         self,
         app: AppDef,
         cfg: T,
-        workspace: Optional[str] = None,
+        workspace: str | Workspace | None = None,
     ) -> str:
         """
         Submits the application to be run by the scheduler.
@@ -139,10 +143,14 @@ class Scheduler(abc.ABC, Generic[T]):
         # pyre-fixme: Generic cfg type passed to resolve
         resolved_cfg = self.run_opts().resolve(cfg)
         if workspace:
-            sched = self
-            assert isinstance(sched, WorkspaceMixin)
-            role = app.roles[0]
-            sched.build_workspace_and_update_role(role, workspace, resolved_cfg)
+            assert isinstance(self, WorkspaceMixin)
+
+            if isinstance(workspace, str):
+                workspace = Workspace.from_str(workspace)
+
+            app.roles[0].workspace = workspace
+            self.build_workspaces(app.roles, resolved_cfg)
+
         # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg
         dryrun_info = self.submit_dryrun(app, resolved_cfg)
         return self.schedule(dryrun_info)
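
With the widened run() signature, callers can pass either the legacy string form or a structured Workspace. A hedged sketch, assuming sched is a concrete workspace-capable scheduler instance and app/cfg are a valid AppDef and config:

    from torchx.specs import Workspace

    # Legacy string form: converted internally via Workspace.from_str(...)
    sched.run(app, cfg, workspace="file:///home/me/project")

    # Structured form: attached to app.roles[0].workspace as-is
    sched.run(app, cfg, workspace=Workspace.from_str("file:///home/me/project"))
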
@@ -177,8 +185,10 @@ class Scheduler(abc.ABC, Generic[T]):
         resolved_cfg = self.run_opts().resolve(cfg)
         # pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
         dryrun_info = self._submit_dryrun(app, resolved_cfg)
+
         for role in app.roles:
             dryrun_info = role.pre_proc(self.backend, dryrun_info)
+
         dryrun_info._app = app
         dryrun_info._cfg = resolved_cfg
         return dryrun_info
@@ -253,6 +263,46 @@ class Scheduler(abc.ABC, Generic[T]):
             # do nothing if the app does not exist
             return
 
+    def delete(self, app_id: str) -> None:
+        """
+        Deletes the job information for the specified ``app_id`` from the
+        scheduler's data-plane. Basically "deep-purging" the job from the
+        scheduler's data-plane. Calling this API on a "live" job (e.g in a
+        non-terminal status such as PENDING or RUNNING) cancels the job.
+
+        Note that this API is only relevant for schedulers for which its
+        data-plane persistently stores the "JobDefinition" (which is often
+        versioned). AWS Batch and Kubernetes are examples of such schedulers.
+        On these schedulers, a finished job may fall out of the data-plane
+        (e.g. really old finished jobs get deleted) but the JobDefinition is
+        typically permanently stored. In this case, calling
+        :py:meth:`~cancel` would not delete the job definition.
+
+        In schedulers with no such feature (e.g. SLURM)
+        :py:meth:`~delete` is the same as :py:meth:`~cancel`, which is the
+        default implementation. Hence implementors of such schedulers need not
+        override this method.
+
+        .. warning::
+            Calling :py:meth:`~delete` on an ``app_id`` that has fallen out of
+            the scheduler's data-plane does nothing. The user is responsible for
+            manually tracking down and cleaning up any dangling resources related
+            to the job.
+        """
+        if self.exists(app_id):
+            self._delete_existing(app_id)
+
+    def _delete_existing(self, app_id: str) -> None:
+        """
+        Deletes the job information for the specified ``app_id`` from the
+        scheduler's data-plane. This method will only be called on an
+        application that exists.
+
+        The default implementation calls :py:meth:`~_cancel_existing` which is
+        appropriate for schedulers without persistent job definitions.
+        """
+        self._cancel_existing(app_id)
+
     def log_iter(
         self,
         app_id: str,
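
To make the delete() contract concrete, a small sketch; the app_id is a placeholder, and on SLURM the default path reduces delete() to an exists() check followed by _cancel_existing():

    from torchx.schedulers import get_scheduler_factories

    scheduler = get_scheduler_factories()["slurm"](session_name="demo")

    # SLURM keeps no persistent JobDefinition, so this falls through to the
    # default implementation: cancel the job if it still exists.
    scheduler.delete("1234")  # placeholder scheduler-native job id
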
@@ -335,18 +385,24 @@ class Scheduler(abc.ABC, Generic[T]):
             f"{self.__class__.__qualname__} does not support application log iteration"
         )
 
-    def _validate(self, app: AppDef, scheduler: str) -> None:
+    def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
         """
-        Validates whether application is consistent with the scheduler.
+        validates before workspace build whether application is consistent with the scheduler.
 
-        Raises:
-            ValueError: if application is not compatible with scheduler
+        Raises error if application is not compatible with scheduler
+        """
+        pass
+
+    def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
+        """
+        Validates after workspace build whether application is consistent with the scheduler.
+
+        Raises error if application is not compatible with scheduler
         """
         for role in app.roles:
             if role.resource == NULL_RESOURCE:
                 raise ValueError(
-                    f"No resource for role: {role.image}."
-                    f" Did you forget to attach resource to the role"
+                    f"No resource for role: {role.image}. Did you forget to attach resource to the role"
                 )
 
 
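A hedged sketch of how a subclass might use the new pre-build hook; the tag-pinning policy below is invented for illustration and is not part of TorchX:

    from torchx.schedulers.docker_scheduler import DockerScheduler
    from torchx.specs import AppDef

    class StrictDockerScheduler(DockerScheduler):
        # _pre_build_validate runs before the workspace build, so bad
        # AppDefs are rejected before any image gets built.
        def _pre_build_validate(self, app: AppDef, scheduler: str, cfg) -> None:
            for role in app.roles:
                if ":" not in role.image:
                    raise ValueError(f"pin an image tag for role {role.name!r} on {scheduler}")
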
torchx/schedulers/aws_batch_scheduler.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 
 This contains the TorchX AWS Batch scheduler which can be used to run TorchX
@@ -51,13 +53,13 @@ from typing import (
     Optional,
     Tuple,
     TYPE_CHECKING,
+    TypedDict,
     TypeVar,
 )
 
 import torchx
 import yaml
 from torchx.schedulers.api import (
-    AppDryRunInfo,
     DescribeAppResponse,
     filter_regex,
     ListAppResponse,
@@ -69,6 +71,7 @@ from torchx.schedulers.devices import get_device_mounts
 from torchx.schedulers.ids import make_unique
 from torchx.specs.api import (
     AppDef,
+    AppDryRunInfo,
     AppState,
     BindMount,
     CfgVal,
@@ -81,14 +84,16 @@ from torchx.specs.api import (
     runopts,
     VolumeMount,
 )
+from torchx.specs.named_resources_aws import instance_type_from_resource
 from torchx.util.types import none_throws
 from torchx.workspace.docker_workspace import DockerWorkspaceMixin
-from typing_extensions import TypedDict
 
 ENV_TORCHX_ROLE_IDX = "TORCHX_ROLE_IDX"
 
 ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME"
 
+ENV_TORCHX_IMAGE = "TORCHX_IMAGE"
+
 DEFAULT_ROLE_NAME = "node"
 
 TAG_TORCHX_VER = "torchx.pytorch.org/version"
@@ -96,6 +101,37 @@ TAG_TORCHX_APPNAME = "torchx.pytorch.org/app-name"
 TAG_TORCHX_USER = "torchx.pytorch.org/user"
 
 
+def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
+    """
+    Parse ulimit string in format: name:softLimit:hardLimit
+    Multiple ulimits separated by commas.
+    """
+    if not ulimits_list:
+        return []
+
+    ulimits = []
+    for ulimit_str in ulimits_list:
+        if not ulimit_str.strip():
+            continue
+
+        parts = ulimit_str.strip().split(":")
+        if len(parts) != 3:
+            raise ValueError(
+                f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
+            )
+
+        name, soft_limit, hard_limit = parts
+        ulimits.append(
+            {
+                "name": name,
+                "softLimit": int(soft_limit) if soft_limit != "-1" else -1,
+                "hardLimit": int(hard_limit) if hard_limit != "-1" else -1,
+            }
+        )
+
+    return ulimits
+
+
 if TYPE_CHECKING:
     from docker import DockerClient
 
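A quick usage sketch of the parser added above, with arbitrary limit values:

    >>> parse_ulimits(["nofile:1024:65536", "core:-1:-1"])
    [{'name': 'nofile', 'softLimit': 1024, 'hardLimit': 65536},
     {'name': 'core', 'softLimit': -1, 'hardLimit': -1}]
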
@@ -169,8 +205,13 @@ def resource_from_resource_requirements(
 
 
 def _role_to_node_properties(
-    role: Role, start_idx: int, privileged: bool = False
-) -> Dict[str, object]:
+    role: Role,
+    start_idx: int,
+    privileged: bool = False,
+    job_role_arn: Optional[str] = None,
+    execution_role_arn: Optional[str] = None,
+    ulimits: Optional[List[Dict[str, Any]]] = None,
+) -> Dict[str, Any]:
     role.mounts += get_device_mounts(role.resource.devices)
 
     mount_points = []
@@ -232,6 +273,7 @@ def _role_to_node_properties(
         "environment": [{"name": k, "value": v} for k, v in role.env.items()],
         "privileged": privileged,
         "resourceRequirements": resource_requirements_from_resource(role.resource),
+        **({"ulimits": ulimits} if ulimits else {}),
         "linuxParameters": {
             # To support PyTorch dataloaders we need to set /dev/shm to larger
             # than the 64M default.
@@ -244,6 +286,14 @@ def _role_to_node_properties(
         "mountPoints": mount_points,
         "volumes": volumes,
     }
+    if job_role_arn:
+        container["jobRoleArn"] = job_role_arn
+    if execution_role_arn:
+        container["executionRoleArn"] = execution_role_arn
+    if role.num_replicas > 0:
+        instance_type = instance_type_from_resource(role.resource)
+        if instance_type is not None:
+            container["instanceType"] = instance_type
 
     return {
         "targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}",
@@ -331,7 +381,7 @@ def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
 
 
 @_thread_local_cache
-def _local_session() -> "boto3.session.Session":
+def _local_session() -> "boto3.session.Session":  # noqa: F821
     import boto3.session
 
     return boto3.session.Session()
@@ -344,6 +394,9 @@ class AWSBatchOpts(TypedDict, total=False):
     privileged: bool
     share_id: Optional[str]
     priority: int
+    job_role_arn: Optional[str]
+    execution_role_arn: Optional[str]
+    ulimits: Optional[list[str]]
 
 
 class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
@@ -487,12 +540,16 @@ class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
             role = values.apply(role)
             role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx)
             role.env[ENV_TORCHX_ROLE_NAME] = str(role.name)
+            role.env[ENV_TORCHX_IMAGE] = role.image
 
             nodes.append(
                 _role_to_node_properties(
                     role,
                     start_idx=node_idx,
                     privileged=cfg["privileged"],
+                    job_role_arn=cfg.get("job_role_arn"),
+                    execution_role_arn=cfg.get("execution_role_arn"),
+                    ulimits=parse_ulimits(cfg.get("ulimits") or []),
                 )
             )
             node_idx += role.num_replicas
@@ -568,6 +625,21 @@ class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
             "Higher number (between 0 and 9999) means higher priority. "
             "This will only take effect if the job queue has a scheduling policy.",
         )
+        opts.add(
+            "job_role_arn",
+            type_=str,
+            help="The Amazon Resource Name (ARN) of the IAM role that the container can assume for AWS permissions.",
+        )
+        opts.add(
+            "execution_role_arn",
+            type_=str,
+            help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
+        )
+        opts.add(
+            "ulimits",
+            type_=List[str],
+            help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
+        )
         return opts
 
     def _get_job_id(self, app_id: str) -> Optional[str]:
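
Pulling the new AWSBatchOpts fields together, a hedged sketch of passing them through the runner API; the queue name, ARNs, and component arguments are placeholders:

    from torchx.runner import get_runner

    app_handle = get_runner().run_component(
        "utils.echo",           # placeholder component
        ["--msg", "hello"],
        scheduler="aws_batch",
        cfg={
            "queue": "my-job-queue",
            "job_role_arn": "arn:aws:iam::123456789012:role/MyJobRole",
            "execution_role_arn": "arn:aws:iam::123456789012:role/MyExecRole",
            "ulimits": ["nofile:65536:65536"],
        },
    )
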
@@ -780,6 +852,8 @@ class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
                 startFromHead=True,
                 **args,
             )
+        # pyre-fixme[66]: Exception handler type annotation `unknown` must
+        #  extend BaseException.
         except self._log_client.exceptions.ResourceNotFoundException:
             return []  # noqa: B901
         if response["nextForwardToken"] == next_token: