torchx-nightly 2025.7.9__py3-none-any.whl → 2025.11.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  2. torchx/cli/cmd_list.py +1 -2
  3. torchx/cli/cmd_run.py +202 -28
  4. torchx/cli/cmd_tracker.py +1 -1
  5. torchx/components/__init__.py +1 -8
  6. torchx/components/dist.py +9 -3
  7. torchx/components/integration_tests/component_provider.py +2 -2
  8. torchx/components/utils.py +1 -1
  9. torchx/distributed/__init__.py +1 -1
  10. torchx/runner/api.py +92 -81
  11. torchx/runner/config.py +11 -9
  12. torchx/runner/events/__init__.py +20 -10
  13. torchx/runner/events/api.py +1 -1
  14. torchx/schedulers/__init__.py +7 -10
  15. torchx/schedulers/api.py +20 -15
  16. torchx/schedulers/aws_batch_scheduler.py +45 -2
  17. torchx/schedulers/docker_scheduler.py +3 -0
  18. torchx/schedulers/kubernetes_scheduler.py +200 -17
  19. torchx/schedulers/local_scheduler.py +1 -0
  20. torchx/schedulers/slurm_scheduler.py +160 -26
  21. torchx/specs/__init__.py +23 -6
  22. torchx/specs/api.py +279 -33
  23. torchx/specs/builders.py +109 -28
  24. torchx/specs/file_linter.py +117 -53
  25. torchx/specs/finder.py +25 -37
  26. torchx/specs/named_resources_aws.py +13 -2
  27. torchx/tracker/__init__.py +2 -2
  28. torchx/tracker/api.py +1 -1
  29. torchx/util/entrypoints.py +1 -6
  30. torchx/util/strings.py +1 -1
  31. torchx/util/types.py +12 -1
  32. torchx/version.py +2 -2
  33. torchx/workspace/api.py +102 -5
  34. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/METADATA +34 -48
  35. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/RECORD +39 -51
  36. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/WHEEL +1 -1
  37. torchx/examples/pipelines/__init__.py +0 -0
  38. torchx/examples/pipelines/kfp/__init__.py +0 -0
  39. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -289
  40. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -71
  41. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -83
  42. torchx/pipelines/kfp/__init__.py +0 -30
  43. torchx/pipelines/kfp/adapter.py +0 -274
  44. torchx/pipelines/kfp/version.py +0 -19
  45. torchx/schedulers/gcp_batch_scheduler.py +0 -497
  46. torchx/schedulers/ray/ray_common.py +0 -22
  47. torchx/schedulers/ray/ray_driver.py +0 -307
  48. torchx/schedulers/ray_scheduler.py +0 -454
  49. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/entry_points.txt +0 -0
  50. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info/licenses}/LICENSE +0 -0
  51. {torchx_nightly-2025.7.9.dist-info → torchx_nightly-2025.11.12.dist-info}/top_level.txt +0 -0
torchx/runner/api.py CHANGED
@@ -1,4 +1,3 @@
- #!/usr/bin/env python3
  # Copyright (c) Meta Platforms, Inc. and affiliates.
  # All rights reserved.
  #
@@ -25,6 +24,7 @@ from typing import (
  Type,
  TYPE_CHECKING,
  TypeVar,
+ Union,
  )

  from torchx.runner.events import log_event
@@ -42,6 +42,7 @@ from torchx.specs import (
  parse_app_handle,
  runopts,
  UnknownAppException,
+ Workspace,
  )
  from torchx.specs.finder import get_component
  from torchx.tracker.api import (
@@ -53,7 +54,7 @@ from torchx.tracker.api import (
  from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID

  from torchx.util.types import none_throws
- from torchx.workspace.api import PkgInfo, WorkspaceBuilder, WorkspaceMixin
+ from torchx.workspace import WorkspaceMixin

  if TYPE_CHECKING:
  from typing_extensions import Self
@@ -129,9 +130,9 @@ class Runner:
  def _get_scheduler_params_from_env(self) -> Dict[str, str]:
  scheduler_params = {}
  for key, value in os.environ.items():
- lower_case_key = key.lower()
- if lower_case_key.startswith("torchx_"):
- scheduler_params[lower_case_key.strip("torchx_")] = value
+ key = key.lower()
+ if key.startswith("torchx_"):
+ scheduler_params[key.removeprefix("torchx_")] = value
  return scheduler_params

  def __enter__(self) -> "Self":
@@ -164,25 +165,13 @@ class Runner:
  for scheduler in self._scheduler_instances.values():
  scheduler.close()

- def build_standalone_workspace(
- self,
- workspace_builder: WorkspaceBuilder[S, T],
- sync: bool = True,
- ) -> PkgInfo[S]:
- """
- Build a standalone workspace for the given role.
- This method is used to build a workspace for a role independent of the scheduler and
- also enables asynchronous workspace building using the Role overrides.
- """
- return workspace_builder.build_workspace(sync)
-
  def run_component(
  self,
  component: str,
- component_args: List[str],
+ component_args: Union[list[str], dict[str, Any]],
  scheduler: str,
  cfg: Optional[Mapping[str, CfgVal]] = None,
- workspace: Optional[str] = None,
+ workspace: Optional[Union[Workspace, str]] = None,
  parent_run_id: Optional[str] = None,
  ) -> AppHandle:
  """
@@ -217,7 +206,7 @@ class Runner:
  ComponentNotFoundException: if the ``component_path`` is failed to resolve.
  """

- with log_event("run_component", workspace=workspace) as ctx:
+ with log_event("run_component") as ctx:
  dryrun_info = self.dryrun_component(
  component,
  component_args,
@@ -228,7 +217,8 @@ class Runner:
  )
  handle = self.schedule(dryrun_info)
  app = none_throws(dryrun_info._app)
- ctx._torchx_event.workspace = workspace
+
+ ctx._torchx_event.workspace = str(workspace)
  ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
  ctx._torchx_event.app_image = app.roles[0].image
  ctx._torchx_event.app_id = parse_app_handle(handle)[2]
@@ -238,10 +228,10 @@ class Runner:
  def dryrun_component(
  self,
  component: str,
- component_args: List[str],
+ component_args: Union[list[str], dict[str, Any]],
  scheduler: str,
  cfg: Optional[Mapping[str, CfgVal]] = None,
- workspace: Optional[str] = None,
+ workspace: Optional[Union[Workspace, str]] = None,
  parent_run_id: Optional[str] = None,
  ) -> AppDryRunInfo:
  """
@@ -249,10 +239,13 @@
  component, but just returns what "would" have run.
  """
  component_def = get_component(component)
+ args_from_cli = component_args if isinstance(component_args, list) else []
+ args_from_json = component_args if isinstance(component_args, dict) else {}
  app = materialize_appdef(
  component_def.fn,
- component_args,
+ args_from_cli,
  self._component_defaults.get(component, None),
+ args_from_json,
  )
  return self.dryrun(
  app,
@@ -267,7 +260,7 @@
  app: AppDef,
  scheduler: str,
  cfg: Optional[Mapping[str, CfgVal]] = None,
- workspace: Optional[str] = None,
+ workspace: Optional[Union[Workspace, str]] = None,
  parent_run_id: Optional[str] = None,
  ) -> AppHandle:
  """
@@ -280,9 +273,7 @@
  An application handle that is used to call other action APIs on the app.
  """

- with log_event(
- api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace
- ) as ctx:
+ with log_event(api="run") as ctx:
  dryrun_info = self.dryrun(
  app,
  scheduler,
@@ -291,10 +282,15 @@
  parent_run_id=parent_run_id,
  )
  handle = self.schedule(dryrun_info)
- ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
- ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image
- ctx._torchx_event.app_id = parse_app_handle(handle)[2]
- ctx._torchx_event.app_metadata = app.metadata
+
+ event = ctx._torchx_event
+ event.scheduler = scheduler
+ event.runcfg = json.dumps(cfg) if cfg else None
+ event.workspace = str(workspace)
+ event.app_id = parse_app_handle(handle)[2]
+ event.app_image = none_throws(dryrun_info._app).roles[0].image
+ event.app_metadata = app.metadata
+
  return handle

  def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
@@ -328,21 +324,22 @@ class Runner:

  """
  scheduler = none_throws(dryrun_info._scheduler)
- app_image = none_throws(dryrun_info._app).roles[0].image
  cfg = dryrun_info._cfg
- with log_event(
- "schedule",
- scheduler,
- app_image=app_image,
- runcfg=json.dumps(cfg) if cfg else None,
- ) as ctx:
+ with log_event("schedule") as ctx:
  sched = self._scheduler(scheduler)
  app_id = sched.schedule(dryrun_info)
  app_handle = make_app_handle(scheduler, self._name, app_id)
+
  app = none_throws(dryrun_info._app)
  self._apps[app_handle] = app
- _, _, app_id = parse_app_handle(app_handle)
- ctx._torchx_event.app_id = app_id
+
+ event = ctx._torchx_event
+ event.scheduler = scheduler
+ event.runcfg = json.dumps(cfg) if cfg else None
+ event.app_id = app_id
+ event.app_image = none_throws(dryrun_info._app).roles[0].image
+ event.app_metadata = app.metadata
+
  return app_handle

  def name(self) -> str:
@@ -353,7 +350,7 @@
  app: AppDef,
  scheduler: str,
  cfg: Optional[Mapping[str, CfgVal]] = None,
- workspace: Optional[str] = None,
+ workspace: Optional[Union[Workspace, str]] = None,
  parent_run_id: Optional[str] = None,
  ) -> AppDryRunInfo:
  """
@@ -422,52 +419,45 @@ class Runner:
  "dryrun",
  scheduler,
  runcfg=json.dumps(cfg) if cfg else None,
- workspace=workspace,
- ):
+ workspace=str(workspace),
+ ) as ctx:
  sched = self._scheduler(scheduler)
  resolved_cfg = sched.run_opts().resolve(cfg)

- # early validation before build workspace
- with log_event(
- "pre_build_validate",
- scheduler,
- ):
- sched._pre_build_validate(app, scheduler, resolved_cfg)
-
- if workspace and isinstance(sched, WorkspaceMixin):
- role = app.roles[0]
- old_img = role.image
-
- logger.info(f"Checking for changes in workspace `{workspace}`...")
- logger.info(
- 'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.'
- )
- with log_event(
- "build_workspace_and_update_role",
- scheduler,
- ) as ctx:
- sched.build_workspace_and_update_role(role, workspace, resolved_cfg)
- ctx._torchx_event.app_image = role.image
- ctx._torchx_event.workspace = workspace
-
- if old_img != role.image:
- logger.info(
- f"Built new image `{role.image}` based on original image `{old_img}`"
- f" and changes in workspace `{workspace}` for role[0]={role.name}."
- )
- else:
- logger.info(
- f"Reusing original image `{old_img}` for role[0]={role.name}."
- " Either a patch was built or no changes to workspace was detected."
+ sched._pre_build_validate(app, scheduler, resolved_cfg)
+
+ if isinstance(sched, WorkspaceMixin):
+ if workspace:
+ # NOTE: torchx originally took workspace as a runner arg and only applied the workspace to role[0]
+ # later, torchx added support for the workspace attr in Role
+ # for BC, give precedence to the workspace argument over the workspace attr for role[0]
+ if app.roles[0].workspace:
+ logger.info(
+ "Overriding role[%d] (%s) workspace to `%s`"
+ "To use the role's workspace attr pass: --workspace='' from CLI or workspace=None programmatically.",
+ 0,
+ role.name,
+ str(app.roles[0].workspace),
+ )
+ app.roles[0].workspace = (
+ Workspace.from_str(workspace)
+ if isinstance(workspace, str)
+ else workspace
  )

- with log_event(
- "validate",
- scheduler,
- ):
- sched._validate(app, scheduler, resolved_cfg)
+ sched.build_workspaces(app.roles, resolved_cfg)
+
+ sched._validate(app, scheduler, resolved_cfg)
  dryrun_info = sched.submit_dryrun(app, resolved_cfg)
  dryrun_info._scheduler = scheduler
+
+ event = ctx._torchx_event
+ event.scheduler = scheduler
+ event.runcfg = json.dumps(cfg) if cfg else None
+ event.app_id = app.name
+ event.app_image = none_throws(dryrun_info._app).roles[0].image
+ event.app_metadata = app.metadata
+
  return dryrun_info

  def scheduler_run_opts(self, scheduler: str) -> runopts:
@@ -486,6 +476,27 @@ class Runner:
  """
  return self._scheduler(scheduler).run_opts()

+ def cfg_from_str(self, scheduler: str, *cfg_literal: str) -> Mapping[str, CfgVal]:
+ """
+ Convenience function around the scheduler's ``runopts.cfg_from_str()`` method.
+
+ Usage:
+
+ .. doctest::
+
+ from torchx.runner import get_runner
+
+ runner = get_runner()
+ cfg = runner.cfg_from_str("local_cwd", "log_dir=/tmp/foobar", "prepend_cwd=True")
+ assert cfg == {"log_dir": "/tmp/foobar", "prepend_cwd": True, "auto_set_cuda_visible_devices": False}
+ """
+
+ opts = self._scheduler(scheduler).run_opts()
+ cfg = {}
+ for cfg_str in cfg_literal:
+ cfg.update(opts.cfg_from_str(cfg_str))
+ return cfg
+
  def scheduler_backends(self) -> List[str]:
  """
  Returns a list of all supported scheduler backends.
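
Note: the `cfg_from_str` helper added above is a thin wrapper over the scheduler's `runopts.cfg_from_str()`; a minimal usage sketch following the doctest shown in the diff (the scheduler name and option values are illustrative, not part of this diff):

    from torchx.runner import get_runner

    # each "key=value" literal is parsed against the scheduler's runopts
    # and merged into a single cfg dict
    runner = get_runner()
    cfg = runner.cfg_from_str("local_cwd", "log_dir=/tmp/foobar", "prepend_cwd=True")
    # cfg -> {"log_dir": "/tmp/foobar", "prepend_cwd": True, ...}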
torchx/runner/config.py CHANGED
@@ -73,7 +73,7 @@ CLI Usage

  #. In addition, it is possible to specify a different config other than .torchxconfig to
  load at runtime. Requirements are that the config path is specified by enviornment
- variable TORCHX_CONFIG. It also disables hierarchy loading configs from multiple
+ variable TORCHXCONFIG. It also disables hierarchy loading configs from multiple
  directories as the cases otherwise.

  #. User level .torchxconfig
@@ -278,14 +278,14 @@ def dump(
  continue

  # serialize list elements with `;` delimiter (consistent with torchx cli)
- if opt.opt_type == List[str]:
+ if opt.is_type_list_of_str:
  # deal with empty or None default lists
  if opt.default:
  # pyre-ignore[6] opt.default type checked already as List[str]
  val = ";".join(opt.default)
  else:
  val = _NONE
- elif opt.opt_type == Dict[str, str]:
+ elif opt.is_type_dict_of_str:
  # deal with empty or None default lists
  if opt.default:
  # pyre-ignore[16] opt.default type checked already as Dict[str, str]
@@ -494,6 +494,8 @@ def find_configs(dirs: Optional[Iterable[str]] = None) -> List[str]:

  config = os.getenv(ENV_TORCHXCONFIG)
  if config is not None:
+ if not config:
+ return []
  configfile = Path(config)
  if not configfile.is_file():
  raise FileNotFoundError(
@@ -536,26 +538,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
  # this also handles empty or None lists
  cfg[name] = None
  else:
- runopt = runopts.get(name)
+ opt = runopts.get(name)

- if runopt is None:
+ if opt is None:
  log.warning(
  f"`{name} = {value}` was declared in the [{section}] section "
  f" of the config file but is not a runopt of `{scheduler}` scheduler."
  f" Remove the entry from the config file to no longer see this warning"
  )
  else:
- if runopt.opt_type is bool:
+ if opt.opt_type is bool:
  # need to handle bool specially since str -> bool is based on
  # str emptiness not value (e.g. bool("False") == True)
  cfg[name] = config.getboolean(section, name)
- elif runopt.opt_type is List[str]:
+ elif opt.is_type_list_of_str:
  cfg[name] = value.split(";")
- elif runopt.opt_type is Dict[str, str]:
+ elif opt.is_type_dict_of_str:
  cfg[name] = {
  s.split(":", 1)[0]: s.split(":", 1)[1]
  for s in value.replace(",", ";").split(";")
  }
  else:
  # pyre-ignore[29]
- cfg[name] = runopt.opt_type(value)
+ cfg[name] = opt.opt_type(value)
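
Note: per the `find_configs` change above, an explicitly empty config-override env var now short-circuits config discovery; a small sketch (the env var spelling follows the TORCHXCONFIG name used in the updated docstring):

    import os

    # point at a single config file, disabling hierarchical .torchxconfig loading
    os.environ["TORCHXCONFIG"] = "/path/to/custom/.torchxconfig"
    # new behavior: an empty value loads no config files at all
    os.environ["TORCHXCONFIG"] = ""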
torchx/runner/events/__init__.py CHANGED
@@ -33,8 +33,9 @@ from torchx.util.session import get_session_id_or_create_new

  from .api import SourceType, TorchxEvent # noqa F401

- # pyre-fixme[9]: _events_logger is a global variable
- _events_logger: logging.Logger = None
+ _events_logger: Optional[logging.Logger] = None
+
+ log: logging.Logger = logging.getLogger(__name__)


  def _get_or_create_logger(destination: str = "null") -> logging.Logger:
@@ -51,19 +52,28 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
  a new logger if None provided.
  """
  global _events_logger
+
  if _events_logger:
  return _events_logger
- logging_handler = get_logging_handler(destination)
- logging_handler.setLevel(logging.DEBUG)
- _events_logger = logging.getLogger(f"torchx-events-{destination}")
- # Do not propagate message to the root logger
- _events_logger.propagate = False
- _events_logger.addHandler(logging_handler)
- return _events_logger
+ else:
+ logging_handler = get_logging_handler(destination)
+ logging_handler.setLevel(logging.DEBUG)
+ _events_logger = logging.getLogger(f"torchx-events-{destination}")
+ # Do not propagate message to the root logger
+ _events_logger.propagate = False
+ _events_logger.addHandler(logging_handler)
+
+ assert _events_logger # make type-checker happy
+ return _events_logger


  def record(event: TorchxEvent, destination: str = "null") -> None:
- _get_or_create_logger(destination).info(event.serialize())
+ try:
+ serialized_event = event.serialize()
+ except Exception:
+ log.exception("failed to serialize event, will not record event")
+ else:
+ _get_or_create_logger(destination).info(serialized_event)


  class log_event:
torchx/runner/events/api.py CHANGED
@@ -29,7 +29,7 @@ class TorchxEvent:
  scheduler: Scheduler that is used to execute request
  api: Api name
  app_id: Unique id that is set by the underlying scheduler
- image: Image/container bundle that is used to execute request.
+ app_image: Image/container bundle that is used to execute request.
  app_metadata: metadata to the app (treatment of metadata is scheduler dependent)
  runcfg: Run config that was used to schedule app.
  source: Type of source the event is generated.
torchx/schedulers/__init__.py CHANGED
@@ -21,8 +21,6 @@ DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
  "kubernetes_mcad": "torchx.schedulers.kubernetes_mcad_scheduler",
  "aws_batch": "torchx.schedulers.aws_batch_scheduler",
  "aws_sagemaker": "torchx.schedulers.aws_sagemaker_scheduler",
- "gcp_batch": "torchx.schedulers.gcp_batch_scheduler",
- "ray": "torchx.schedulers.ray_scheduler",
  "lsf": "torchx.schedulers.lsf_scheduler",
  }

@@ -51,15 +49,14 @@ def get_scheduler_factories(
  The first scheduler in the dictionary is used as the default scheduler.
  """

- default_schedulers: dict[str, SchedulerFactory] = {}
- for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
- default_schedulers[scheduler] = _defer_load_scheduler(path)
+ if skip_defaults:
+ default_schedulers = {}
+ else:
+ default_schedulers: dict[str, SchedulerFactory] = {}
+ for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
+ default_schedulers[scheduler] = _defer_load_scheduler(path)

- return load_group(
- group,
- default=default_schedulers,
- skip_defaults=skip_defaults,
- )
+ return load_group(group, default=default_schedulers)


  def get_default_scheduler_name() -> str:
torchx/schedulers/api.py CHANGED
@@ -1,4 +1,3 @@
- #!/usr/bin/env python3
  # Copyright (c) Meta Platforms, Inc. and affiliates.
  # All rights reserved.
  #
@@ -12,7 +11,7 @@ import re
  from dataclasses import dataclass, field
  from datetime import datetime
  from enum import Enum
- from typing import Generic, Iterable, List, Optional, TypeVar
+ from typing import Generic, Iterable, List, Optional, TypeVar, Union

  from torchx.specs import (
  AppDef,
@@ -22,8 +21,9 @@ from torchx.specs import (
  Role,
  RoleStatus,
  runopts,
+ Workspace,
  )
- from torchx.workspace.api import WorkspaceMixin
+ from torchx.workspace import WorkspaceMixin


  DAYS_IN_2_WEEKS = 14
@@ -131,7 +131,7 @@ class Scheduler(abc.ABC, Generic[T, A, D]):
  self,
  app: A,
  cfg: T,
- workspace: Optional[str] = None,
+ workspace: str | Workspace | None = None,
  ) -> str:
  """
  Submits the application to be run by the scheduler.
@@ -144,10 +144,14 @@ class Scheduler(abc.ABC, Generic[T, A, D]):
  # pyre-fixme: Generic cfg type passed to resolve
  resolved_cfg = self.run_opts().resolve(cfg)
  if workspace:
- sched = self
- assert isinstance(sched, WorkspaceMixin)
- role = app.roles[0]
- sched.build_workspace_and_update_role(role, workspace, resolved_cfg)
+ assert isinstance(self, WorkspaceMixin)
+
+ if isinstance(workspace, str):
+ workspace = Workspace.from_str(workspace)
+
+ app.roles[0].workspace = workspace
+ self.build_workspaces(app.roles, resolved_cfg)
+
  # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg
  dryrun_info = self.submit_dryrun(app, resolved_cfg)
  return self.schedule(dryrun_info)
@@ -356,13 +360,14 @@ class Scheduler(abc.ABC, Generic[T, A, D]):

  Raises error if application is not compatible with scheduler
  """
- if isinstance(app, AppDef):
- for role in app.roles:
- if role.resource == NULL_RESOURCE:
- raise ValueError(
- f"No resource for role: {role.image}."
- f" Did you forget to attach resource to the role"
- )
+ if not isinstance(app, AppDef):
+ return
+
+ for role in app.roles:
+ if role.resource == NULL_RESOURCE:
+ raise ValueError(
+ f"No resource for role: {role.image}. Did you forget to attach resource to the role"
+ )


  def filter_regex(regex: str, data: Iterable[str]) -> Iterable[str]:
torchx/schedulers/aws_batch_scheduler.py CHANGED
@@ -92,6 +92,8 @@ ENV_TORCHX_ROLE_IDX = "TORCHX_ROLE_IDX"

  ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME"

+ ENV_TORCHX_IMAGE = "TORCHX_IMAGE"
+
  DEFAULT_ROLE_NAME = "node"

  TAG_TORCHX_VER = "torchx.pytorch.org/version"
@@ -99,6 +101,37 @@ TAG_TORCHX_APPNAME = "torchx.pytorch.org/app-name"
  TAG_TORCHX_USER = "torchx.pytorch.org/user"


+ def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
+ """
+ Parse ulimit string in format: name:softLimit:hardLimit
+ Multiple ulimits separated by commas.
+ """
+ if not ulimits_list:
+ return []
+
+ ulimits = []
+ for ulimit_str in ulimits_list:
+ if not ulimit_str.strip():
+ continue
+
+ parts = ulimit_str.strip().split(":")
+ if len(parts) != 3:
+ raise ValueError(
+ f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
+ )
+
+ name, soft_limit, hard_limit = parts
+ ulimits.append(
+ {
+ "name": name,
+ "softLimit": int(soft_limit) if soft_limit != "-1" else -1,
+ "hardLimit": int(hard_limit) if hard_limit != "-1" else -1,
+ }
+ )
+
+ return ulimits
+
+
  if TYPE_CHECKING:
  from docker import DockerClient

@@ -177,7 +210,8 @@ def _role_to_node_properties(
  privileged: bool = False,
  job_role_arn: Optional[str] = None,
  execution_role_arn: Optional[str] = None,
- ) -> Dict[str, object]:
+ ulimits: Optional[List[Dict[str, Any]]] = None,
+ ) -> Dict[str, Any]:
  role.mounts += get_device_mounts(role.resource.devices)

  mount_points = []
@@ -239,6 +273,7 @@
  "environment": [{"name": k, "value": v} for k, v in role.env.items()],
  "privileged": privileged,
  "resourceRequirements": resource_requirements_from_resource(role.resource),
+ **({"ulimits": ulimits} if ulimits else {}),
  "linuxParameters": {
  # To support PyTorch dataloaders we need to set /dev/shm to larger
  # than the 64M default.
@@ -255,7 +290,7 @@
  container["jobRoleArn"] = job_role_arn
  if execution_role_arn:
  container["executionRoleArn"] = execution_role_arn
- if role.num_replicas > 1:
+ if role.num_replicas > 0:
  instance_type = instance_type_from_resource(role.resource)
  if instance_type is not None:
  container["instanceType"] = instance_type
@@ -361,6 +396,7 @@ class AWSBatchOpts(TypedDict, total=False):
  priority: int
  job_role_arn: Optional[str]
  execution_role_arn: Optional[str]
+ ulimits: Optional[list[str]]


  class AWSBatchScheduler(
@@ -506,6 +542,7 @@ class AWSBatchScheduler(
  role = values.apply(role)
  role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx)
  role.env[ENV_TORCHX_ROLE_NAME] = str(role.name)
+ role.env[ENV_TORCHX_IMAGE] = role.image

  nodes.append(
  _role_to_node_properties(
@@ -514,6 +551,7 @@ class AWSBatchScheduler(
  privileged=cfg["privileged"],
  job_role_arn=cfg.get("job_role_arn"),
  execution_role_arn=cfg.get("execution_role_arn"),
+ ulimits=parse_ulimits(cfg.get("ulimits") or []),
  )
  )
  node_idx += role.num_replicas
@@ -599,6 +637,11 @@ class AWSBatchScheduler(
  type_=str,
  help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
  )
+ opts.add(
+ "ulimits",
+ type_=List[str],
+ help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
+ )
  return opts

  def _get_job_id(self, app_id: str) -> Optional[str]:
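
Note: the `ulimits` run option added above takes entries of the form name:softLimit:hardLimit; a minimal sketch of what `parse_ulimits` produces for the AWS Batch container properties (the values below are illustrative):

    # one dict per entry, with name/softLimit/hardLimit keys; "-1" is kept as the integer -1
    parse_ulimits(["nofile:1024:65536"])
    # -> [{"name": "nofile", "softLimit": 1024, "hardLimit": 65536}]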
torchx/schedulers/docker_scheduler.py CHANGED
@@ -84,6 +84,8 @@ LABEL_APP_ID: str = "torchx.pytorch.org/app-id"
  LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
  LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"

+ ENV_TORCHX_IMAGE: str = "TORCHX_IMAGE"
+
  NETWORK = "torchx"


@@ -279,6 +281,7 @@ class DockerScheduler(

  # configure distributed host envs
  env["TORCHX_RANK0_HOST"] = rank0_name
+ env[ENV_TORCHX_IMAGE] = replica_role.image

  c = DockerContainer(
  image=replica_role.image,