torchx-nightly 2024.2.12__py3-none-any.whl → 2025.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly might be problematic.

Files changed (102)
  1. torchx/__init__.py +2 -0
  2. torchx/apps/serve/serve.py +2 -0
  3. torchx/apps/utils/booth_main.py +2 -0
  4. torchx/apps/utils/copy_main.py +2 -0
  5. torchx/apps/utils/process_monitor.py +2 -0
  6. torchx/cli/__init__.py +2 -0
  7. torchx/cli/argparse_util.py +38 -3
  8. torchx/cli/cmd_base.py +2 -0
  9. torchx/cli/cmd_cancel.py +2 -0
  10. torchx/cli/cmd_configure.py +2 -0
  11. torchx/cli/cmd_describe.py +2 -0
  12. torchx/cli/cmd_list.py +2 -0
  13. torchx/cli/cmd_log.py +6 -24
  14. torchx/cli/cmd_run.py +30 -12
  15. torchx/cli/cmd_runopts.py +2 -0
  16. torchx/cli/cmd_status.py +2 -0
  17. torchx/cli/cmd_tracker.py +2 -0
  18. torchx/cli/colors.py +2 -0
  19. torchx/cli/main.py +2 -0
  20. torchx/components/__init__.py +2 -0
  21. torchx/components/component_test_base.py +2 -0
  22. torchx/components/dist.py +2 -0
  23. torchx/components/integration_tests/component_provider.py +2 -0
  24. torchx/components/integration_tests/integ_tests.py +2 -0
  25. torchx/components/serve.py +2 -0
  26. torchx/components/structured_arg.py +2 -0
  27. torchx/components/utils.py +2 -0
  28. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  29. torchx/examples/apps/lightning/data.py +5 -3
  30. torchx/examples/apps/lightning/model.py +2 -0
  31. torchx/examples/apps/lightning/profiler.py +7 -4
  32. torchx/examples/apps/lightning/train.py +2 -0
  33. torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
  34. torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
  35. torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/pipelines/kfp/__init__.py +2 -0
  39. torchx/pipelines/kfp/adapter.py +7 -4
  40. torchx/pipelines/kfp/version.py +2 -0
  41. torchx/runner/__init__.py +2 -0
  42. torchx/runner/api.py +78 -20
  43. torchx/runner/config.py +34 -3
  44. torchx/runner/events/__init__.py +37 -3
  45. torchx/runner/events/api.py +13 -2
  46. torchx/runner/events/handlers.py +2 -0
  47. torchx/runtime/tracking/__init__.py +2 -0
  48. torchx/runtime/tracking/api.py +2 -0
  49. torchx/schedulers/__init__.py +10 -5
  50. torchx/schedulers/api.py +3 -1
  51. torchx/schedulers/aws_batch_scheduler.py +4 -0
  52. torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
  53. torchx/schedulers/devices.py +17 -4
  54. torchx/schedulers/docker_scheduler.py +38 -8
  55. torchx/schedulers/gcp_batch_scheduler.py +8 -9
  56. torchx/schedulers/ids.py +2 -0
  57. torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
  58. torchx/schedulers/kubernetes_scheduler.py +31 -5
  59. torchx/schedulers/local_scheduler.py +45 -6
  60. torchx/schedulers/lsf_scheduler.py +3 -1
  61. torchx/schedulers/ray/ray_driver.py +7 -7
  62. torchx/schedulers/ray_scheduler.py +1 -1
  63. torchx/schedulers/slurm_scheduler.py +3 -1
  64. torchx/schedulers/streams.py +2 -0
  65. torchx/specs/__init__.py +49 -8
  66. torchx/specs/api.py +87 -5
  67. torchx/specs/builders.py +61 -19
  68. torchx/specs/file_linter.py +8 -2
  69. torchx/specs/finder.py +2 -0
  70. torchx/specs/named_resources_aws.py +109 -2
  71. torchx/specs/named_resources_generic.py +2 -0
  72. torchx/specs/test/components/__init__.py +2 -0
  73. torchx/specs/test/components/a/__init__.py +2 -0
  74. torchx/specs/test/components/a/b/__init__.py +2 -0
  75. torchx/specs/test/components/a/b/c.py +2 -0
  76. torchx/specs/test/components/c/__init__.py +2 -0
  77. torchx/specs/test/components/c/d.py +2 -0
  78. torchx/tracker/__init__.py +2 -0
  79. torchx/tracker/api.py +4 -4
  80. torchx/tracker/backend/fsspec.py +2 -0
  81. torchx/util/cuda.py +2 -0
  82. torchx/util/datetime.py +2 -0
  83. torchx/util/entrypoints.py +6 -2
  84. torchx/util/io.py +2 -0
  85. torchx/util/log_tee_helpers.py +210 -0
  86. torchx/util/modules.py +2 -0
  87. torchx/util/session.py +42 -0
  88. torchx/util/shlex.py +2 -0
  89. torchx/util/strings.py +2 -0
  90. torchx/util/types.py +20 -2
  91. torchx/version.py +3 -1
  92. torchx/workspace/__init__.py +2 -0
  93. torchx/workspace/api.py +34 -1
  94. torchx/workspace/dir_workspace.py +2 -0
  95. torchx/workspace/docker_workspace.py +25 -2
  96. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/METADATA +55 -48
  97. torchx_nightly-2025.1.14.dist-info/RECORD +123 -0
  98. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/WHEEL +1 -1
  99. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/entry_points.txt +0 -1
  100. torchx_nightly-2024.2.12.dist-info/RECORD +0 -119
  101. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/LICENSE +0 -0
  102. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/top_level.txt +0 -0
torchx/pipelines/kfp/adapter.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import json
 import os
 import os.path
@@ -48,7 +50,9 @@ def component_spec_from_app(app: api.AppDef) -> Tuple[str, api.Role]:
 
     role = app.roles[0]
     assert (
-        role.num_replicas == 1
+        role.num_replicas
+        == 1
+        # pyre-fixme[16]: `AppDef` has no attribute `num_replicas`.
     ), f"KFP adapter only supports one replica, got {app.num_replicas}"
 
     command = [role.entrypoint, *role.args]
@@ -74,8 +78,7 @@ class ContainerFactory(Protocol):
     kfp.dsl.ContainerOp.
     """
 
-    def __call__(self, *args: object, **kwargs: object) -> dsl.ContainerOp:
-        ...
+    def __call__(self, *args: object, **kwargs: object) -> dsl.ContainerOp: ...
 
 
 class KFPContainerFactory(ContainerFactory, Protocol):
@@ -104,7 +107,7 @@ def component_from_app(
         app: The AppDef to generate a KFP container factory for.
         ui_metadata: KFP UI Metadata to output so you can have model results show
             up in the UI. See
-            https://www.kubeflow.org/docs/components/pipelines/sdk/output-viewer/
+            https://www.kubeflow.org/docs/components/pipelines/legacy-v1/sdk/output-viewer/
            for more info on the format.
 
    >>> from torchx import specs
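The adapter hunks above only adjust assert formatting and a documentation link; the public call pattern is unchanged. As a rough usage sketch (assuming the public torchx.pipelines.kfp.adapter API; the image name and role details are made up):

from torchx import specs
from torchx.pipelines.kfp.adapter import component_from_app

app = specs.AppDef(
    name="trainer",
    roles=[
        specs.Role(
            name="trainer",
            image="example.com/trainer:latest",  # hypothetical image
            entrypoint="python",
            args=["-m", "trainer.main"],
            num_replicas=1,  # the adapter asserts exactly one replica
        )
    ],
)

# component_from_app returns a ContainerFactory, i.e. a callable that produces a kfp.dsl.ContainerOp
trainer_factory = component_from_app(app)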
torchx/pipelines/kfp/version.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 # Follows PEP-0440 version scheme guidelines
 # https://www.python.org/dev/peps/pep-0440/#version-scheme
 #
torchx/runner/__init__.py CHANGED
@@ -5,4 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from torchx.runner.api import get_runner, Runner  # noqa: F401 F403
torchx/runner/api.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import json
 import logging
 import os
@@ -12,7 +14,7 @@ import time
 import warnings
 from datetime import datetime
 from types import TracebackType
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar
 
 from torchx.runner.events import log_event
 from torchx.schedulers import get_scheduler_factories, SchedulerFactory
@@ -37,9 +39,10 @@ from torchx.tracker.api import (
     ENV_TORCHX_TRACKERS,
     tracker_config_env_var_name,
 )
+from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID
 
 from torchx.util.types import none_throws
-from torchx.workspace.api import WorkspaceMixin
+from torchx.workspace.api import PkgInfo, WorkspaceBuilder, WorkspaceMixin
 
 from .config import get_config, get_configs
 
@@ -47,6 +50,8 @@ logger: logging.Logger = logging.getLogger(__name__)
 
 
 NONE: str = "<NONE>"
+S = TypeVar("S")
+T = TypeVar("T")
 
 
 def get_configured_trackers() -> Dict[str, Optional[str]]:
@@ -96,7 +101,10 @@ class Runner:
         """
         self._name: str = name
         self._scheduler_factories = scheduler_factories
-        self._scheduler_params: Dict[str, object] = scheduler_params or {}
+        self._scheduler_params: Dict[str, Any] = {
+            **(self._get_scheduler_params_from_env()),
+            **(scheduler_params or {}),
+        }
         # pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
         self._scheduler_instances: Dict[str, Scheduler] = {}
         self._apps: Dict[AppHandle, AppDef] = {}
@@ -104,6 +112,14 @@
         # component_name -> map of component_fn_param_name -> user-specified default val encoded as str
         self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}
 
+    def _get_scheduler_params_from_env(self) -> Dict[str, str]:
+        scheduler_params = {}
+        for key, value in os.environ.items():
+            lower_case_key = key.lower()
+            if lower_case_key.startswith("torchx_"):
+                scheduler_params[lower_case_key.strip("torchx_")] = value
+        return scheduler_params
+
     def __enter__(self) -> "Runner":
         return self
 
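The new _get_scheduler_params_from_env hook folds TORCHX_-prefixed environment variables into the scheduler params that the Runner hands to each scheduler factory; values passed explicitly via scheduler_params still win because they are merged last. A minimal sketch of the intended behavior (the variable name below is hypothetical, and the exact key normalization is internal to the Runner):

import os
from torchx.runner import get_runner

os.environ["TORCHX_SOME_PARAM"] = "some-value"  # hypothetical TORCHX_-prefixed variable

runner = get_runner()  # the Runner now reads TORCHX_* variables when building scheduler params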
@@ -131,9 +147,21 @@
         It is ok to call this method multiple times on the same runner object.
         """
 
-        for name, scheduler in self._scheduler_instances.items():
+        for scheduler in self._scheduler_instances.values():
             scheduler.close()
 
+    def build_standalone_workspace(
+        self,
+        workspace_builder: WorkspaceBuilder[S, T],
+        sync: bool = True,
+    ) -> PkgInfo[S]:
+        """
+        Build a standalone workspace for the given role.
+        This method is used to build a workspace for a role independent of the scheduler and
+        also enables asynchronous workspace building using the Role overrides.
+        """
+        return workspace_builder.build_workspace(sync)
+
     def run_component(
         self,
         component: str,
@@ -175,15 +203,23 @@
             ComponentNotFoundException: if the ``component_path`` is failed to resolve.
         """
 
-        dryrun_info = self.dryrun_component(
-            component,
-            component_args,
-            scheduler,
-            cfg=cfg,
-            workspace=workspace,
-            parent_run_id=parent_run_id,
-        )
-        return self.schedule(dryrun_info)
+        with log_event("run_component", workspace=workspace) as ctx:
+            dryrun_info = self.dryrun_component(
+                component,
+                component_args,
+                scheduler,
+                cfg=cfg,
+                workspace=workspace,
+                parent_run_id=parent_run_id,
+            )
+            handle = self.schedule(dryrun_info)
+            app = none_throws(dryrun_info._app)
+            ctx._torchx_event.workspace = workspace
+            ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
+            ctx._torchx_event.app_image = app.roles[0].image
+            ctx._torchx_event.app_id = parse_app_handle(handle)[2]
+            ctx._torchx_event.app_metadata = app.metadata
+            return handle
 
     def dryrun_component(
         self,
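run_component is now wrapped in a log_event context that records the workspace, scheduler, app image, app id, and app metadata on the emitted event; the caller-facing behavior is unchanged. A caller-side sketch (using the builtin utils.echo component and the local_cwd scheduler purely for illustration):

from torchx.runner import get_runner

with get_runner() as runner:
    app_handle = runner.run_component(
        "utils.echo",
        ["--msg", "hello world"],
        scheduler="local_cwd",
    )
    print(runner.status(app_handle))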
@@ -230,10 +266,22 @@
             An application handle that is used to call other action APIs on the app.
         """
 
-        dryrun_info = self.dryrun(
-            app, scheduler, cfg=cfg, workspace=workspace, parent_run_id=parent_run_id
-        )
-        return self.schedule(dryrun_info)
+        with log_event(
+            api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace
+        ) as ctx:
+            dryrun_info = self.dryrun(
+                app,
+                scheduler,
+                cfg=cfg,
+                workspace=workspace,
+                parent_run_id=parent_run_id,
+            )
+            handle = self.schedule(dryrun_info)
+            ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
+            ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image
+            ctx._torchx_event.app_id = parse_app_handle(handle)[2]
+            ctx._torchx_event.app_metadata = app.metadata
+            return handle
 
     def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
         """
@@ -343,6 +391,7 @@ class Runner:
                 role.env[ENV_TORCHX_JOB_ID] = make_app_handle(
                     scheduler, self._name, macros.app_id
                 )
+                role.env[TORCHX_INTERNAL_SESSION_ID] = get_session_id_or_create_new()
 
                 if parent_run_id:
                     role.env[ENV_TORCHX_PARENT_RUN_ID] = parent_run_id
@@ -355,7 +404,12 @@
                     role.env[tracker_config_env_var_name(name)] = config
 
         cfg = cfg or dict()
-        with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None):
+        with log_event(
+            "dryrun",
+            scheduler,
+            runcfg=json.dumps(cfg) if cfg else None,
+            workspace=workspace,
+        ):
             sched = self._scheduler(scheduler)
             resolved_cfg = sched.run_opts().resolve(cfg)
             if workspace and isinstance(sched, WorkspaceMixin):
@@ -379,7 +433,11 @@
                     " Either a patch was built or no changes to workspace was detected."
                 )
 
-            sched._validate(app, scheduler)
+            with log_event(
+                "validate",
+                scheduler,
+            ):
+                sched._validate(app, scheduler, resolved_cfg)
             dryrun_info = sched.submit_dryrun(app, resolved_cfg)
             dryrun_info._scheduler = scheduler
             return dryrun_info
@@ -654,7 +712,7 @@
     def _scheduler_app_id(
         self,
         app_handle: AppHandle,
-        check_session: bool = True
+        check_session: bool = True,
         # pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
     ) -> Tuple[Scheduler, str, str]:
         """
torchx/runner/config.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Status: Beta
 
@@ -195,7 +197,15 @@ def _configparser() -> configparser.ConfigParser:
 
 
 def _get_scheduler(name: str) -> Scheduler:
-    schedulers = get_scheduler_factories()
+    schedulers = {
+        **get_scheduler_factories(),
+        **(
+            get_scheduler_factories(
+                group="torchx.schedulers.orchestrator", skip_defaults=True
+            )
+            or {}
+        ),
+    }
     if name not in schedulers:
         raise ValueError(
             f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"
@@ -239,7 +249,16 @@ def dump(
     if schedulers:
         scheds = schedulers
     else:
-        scheds = get_scheduler_factories().keys()
+        scheduler_factories = {
+            **get_scheduler_factories(),
+            **(
+                get_scheduler_factories(
+                    group="torchx.schedulers.orchestrator", skip_defaults=True
+                )
+                or {}
+            ),
+        }
+        scheds = scheduler_factories.keys()
 
     config = _configparser()
     for sched_name in scheds:
@@ -266,6 +285,13 @@ def dump(
                     val = ";".join(opt.default)
                 else:
                     val = _NONE
+            elif opt.opt_type == Dict[str, str]:
+                # deal with empty or None default lists
+                if opt.default:
+                    # pyre-ignore[16] opt.default type checked already as Dict[str, str]
+                    val = ";".join([f"{k}:{v}" for k, v in opt.default.items()])
+                else:
+                    val = _NONE
             else:
                 val = f"{opt.default}"
 
@@ -480,7 +506,7 @@ def find_configs(dirs: Optional[Iterable[str]] = None) -> List[str]:
         dirs = DEFAULT_CONFIG_DIRS
     for d in dirs:
         configfile = Path(d) / CONFIG_FILE
-        if configfile.exists():
+        if os.access(configfile, os.R_OK):
             config_files.append(str(configfile))
     return config_files
 
@@ -525,6 +551,11 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
                 cfg[name] = config.getboolean(section, name)
             elif runopt.opt_type is List[str]:
                 cfg[name] = value.split(";")
+            elif runopt.opt_type is Dict[str, str]:
+                cfg[name] = {
+                    s.split(":", 1)[0]: s.split(":", 1)[1]
+                    for s in value.replace(",", ";").split(";")
+                }
             else:
                 # pyre-ignore[29]
                 cfg[name] = runopt.opt_type(value)
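With this change a Dict[str, str] run option can round-trip through .torchxconfig: dump writes it as key:value pairs joined by ";", and load accepts ";" or "," as the pair separator, splitting each pair on the first ":". The parsing rule, shown standalone (the option value below is made up):

value = "team:ml;cost-center:1234,env:prod"
parsed = {
    s.split(":", 1)[0]: s.split(":", 1)[1]
    for s in value.replace(",", ";").split(";")
}
assert parsed == {"team": "ml", "cost-center": "1234", "env": "prod"}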
torchx/runner/events/__init__.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Module contains events processing mechanisms that are integrated with the standard python logging.
 
@@ -18,13 +20,16 @@ Example of usage:
 
 """
 
+import json
 import logging
+import sys
 import time
 import traceback
 from types import TracebackType
-from typing import Optional, Type
+from typing import Dict, Optional, Type
 
 from torchx.runner.events.handlers import get_logging_handler
+from torchx.util.session import get_session_id_or_create_new
 
 from .api import SourceType, TorchxEvent  # noqa F401
 
@@ -82,17 +87,28 @@ class log_event:
         scheduler: Optional[str] = None,
         app_id: Optional[str] = None,
         app_image: Optional[str] = None,
+        app_metadata: Optional[Dict[str, str]] = None,
         runcfg: Optional[str] = None,
+        workspace: Optional[str] = None,
     ) -> None:
         self._torchx_event: TorchxEvent = self._generate_torchx_event(
-            api, scheduler or "", app_id, app_image=app_image, runcfg=runcfg
+            api,
+            scheduler or "",
+            app_id,
+            app_image=app_image,
+            app_metadata=app_metadata,
+            runcfg=runcfg,
+            workspace=workspace,
         )
         self._start_cpu_time_ns = 0
         self._start_wall_time_ns = 0
+        self._start_epoch_time_usec = 0
 
     def __enter__(self) -> "log_event":
         self._start_cpu_time_ns = time.process_time_ns()
         self._start_wall_time_ns = time.perf_counter_ns()
+        self._torchx_event.start_epoch_time_usec = int(time.time() * 1_000_000)
+
         return self
 
     def __exit__(
@@ -109,6 +125,20 @@
         ) // 1000
         if traceback_type:
             self._torchx_event.raw_exception = traceback.format_exc()
+            typ, value, tb = sys.exc_info()
+            if tb:
+                last_frame = traceback.extract_tb(tb)[-1]
+                self._torchx_event.exception_source_location = json.dumps(
+                    {
+                        "filename": last_frame.filename,
+                        "lineno": last_frame.lineno,
+                        "name": last_frame.name,
+                    }
+                )
+            if exec_type:
+                self._torchx_event.exception_type = exec_type.__name__
+            if exec_value:
+                self._torchx_event.exception_message = str(exec_value)
         record(self._torchx_event)
 
     def _generate_torchx_event(
@@ -117,15 +147,19 @@
         scheduler: str,
         app_id: Optional[str] = None,
         app_image: Optional[str] = None,
+        app_metadata: Optional[Dict[str, str]] = None,
         runcfg: Optional[str] = None,
         source: SourceType = SourceType.UNKNOWN,
+        workspace: Optional[str] = None,
     ) -> TorchxEvent:
         return TorchxEvent(
-            session=app_id or "",
+            session=get_session_id_or_create_new(),
             scheduler=scheduler,
             api=api,
             app_id=app_id,
             app_image=app_image,
+            app_metadata=app_metadata,
             runcfg=runcfg,
             source=source,
+            workspace=workspace,
        )
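log_event now accepts app_metadata and workspace, stamps start_epoch_time_usec on entry, and captures the exception type, message, and source location on exit; the session field is taken from get_session_id_or_create_new() instead of the app id. A small sketch of the context-manager usage (values are illustrative; log_event is normally driven by the Runner):

from torchx.runner.events import log_event

with log_event(
    api="describe",
    scheduler="local_cwd",
    app_id="echo-xyz123",             # hypothetical app id
    workspace="file://./my_project",  # new: recorded on the emitted event
) as ctx:
    pass  # on exit, cpu/wall time and any exception details are recorded on ctx._torchx_event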
torchx/runner/events/api.py CHANGED
@@ -5,10 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import json
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 
 class SourceType(str, Enum):
@@ -23,15 +25,18 @@ class TorchxEvent:
     The class represents the event produced by ``torchx.runner`` api calls.
 
     Arguments:
-        session: Session id that was used to execute request.
+        session: Session id of the current run
         scheduler: Scheduler that is used to execute request
         api: Api name
         app_id: Unique id that is set by the underlying scheduler
         image: Image/container bundle that is used to execute request.
+        app_metadata: metadata to the app (treatment of metadata is scheduler dependent)
         runcfg: Run config that was used to schedule app.
         source: Type of source the event is generated.
        cpu_time_usec: CPU time spent in usec
        wall_time_usec: Wall time spent in usec
+        start_epoch_time_usec: Epoch time in usec when runner event starts
+        Workspace: Track how different workspaces/no workspace affects build and scheduler
     """
 
     session: str
@@ -39,11 +44,17 @@ class TorchxEvent:
     api: str
     app_id: Optional[str] = None
     app_image: Optional[str] = None
+    app_metadata: Optional[Dict[str, str]] = None
     runcfg: Optional[str] = None
     raw_exception: Optional[str] = None
     source: SourceType = SourceType.UNKNOWN
     cpu_time_usec: Optional[int] = None
     wall_time_usec: Optional[int] = None
+    start_epoch_time_usec: Optional[int] = None
+    workspace: Optional[str] = None
+    exception_type: Optional[str] = None
+    exception_message: Optional[str] = None
+    exception_source_location: Optional[str] = None
 
     def __str__(self) -> str:
         return self.serialize()
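For reference, this is roughly what an event carrying the new fields could look like when constructed directly (all values are made up; in practice log_event and the Runner populate them):

from torchx.runner.events.api import SourceType, TorchxEvent

event = TorchxEvent(
    session="session-1234",
    scheduler="local_cwd",
    api="run",
    app_id="echo-xyz123",
    app_metadata={"team": "ml"},
    workspace="file://./my_project",
    start_epoch_time_usec=1_700_000_000_000_000,
    source=SourceType.UNKNOWN,
)
print(event)  # __str__ delegates to serialize()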
torchx/runner/events/handlers.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import logging
 from typing import Dict
 
torchx/runtime/tracking/__init__.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 .. note:: EXPERIMENTAL, USE AT YOUR OWN RISK, APIs SUBJECT TO CHANGE
 
torchx/runtime/tracking/api.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import json
 from typing import Dict, Union
torchx/schedulers/__init__.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import importlib
 from typing import Dict, Mapping
 
@@ -19,6 +21,7 @@ DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
     "kubernetes": "torchx.schedulers.kubernetes_scheduler",
     "kubernetes_mcad": "torchx.schedulers.kubernetes_mcad_scheduler",
     "aws_batch": "torchx.schedulers.aws_batch_scheduler",
+    "aws_sagemaker": "torchx.schedulers.aws_sagemaker_scheduler",
     "gcp_batch": "torchx.schedulers.gcp_batch_scheduler",
     "ray": "torchx.schedulers.ray_scheduler",
     "lsf": "torchx.schedulers.lsf_scheduler",
@@ -27,8 +30,7 @@ DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
 
 class SchedulerFactory(Protocol):
     # pyre-fixme: Scheduler opts
-    def __call__(self, session_name: str, **kwargs: object) -> Scheduler:
-        ...
+    def __call__(self, session_name: str, **kwargs: object) -> Scheduler: ...
 
 
 def _defer_load_scheduler(path: str) -> SchedulerFactory:
@@ -40,9 +42,11 @@ def _defer_load_scheduler(path: str) -> SchedulerFactory:
     return run
 
 
-def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
+def get_scheduler_factories(
+    group: str = "torchx.schedulers", skip_defaults: bool = False
+) -> Dict[str, SchedulerFactory]:
     """
-    get_scheduler_factories returns all the available schedulers names and the
+    get_scheduler_factories returns all the available schedulers names under `group` and the
     method to instantiate them.
 
     The first scheduler in the dictionary is used as the default scheduler.
@@ -53,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
         default_schedulers[scheduler] = _defer_load_scheduler(path)
 
     return load_group(
-        "torchx.schedulers",
+        group,
         default=default_schedulers,
+        skip_defaults=skip_defaults,
     )
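get_scheduler_factories now takes an entry-point group and a skip_defaults flag; torchx/runner/config.py (above) uses this to overlay factories registered under the torchx.schedulers.orchestrator group on top of the defaults. A sketch mirroring that merge:

from torchx.schedulers import get_scheduler_factories

factories = {
    **get_scheduler_factories(),  # default "torchx.schedulers" entry-point group
    **(
        get_scheduler_factories(
            group="torchx.schedulers.orchestrator", skip_defaults=True
        )
        or {}
    ),
}
print(sorted(factories))  # the defaults now include "aws_sagemaker"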
torchx/schedulers/api.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import re
 from dataclasses import dataclass, field
@@ -335,7 +337,7 @@ class Scheduler(abc.ABC, Generic[T]):
             f"{self.__class__.__qualname__} does not support application log iteration"
         )
 
-    def _validate(self, app: AppDef, scheduler: str) -> None:
+    def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
         """
         Validates whether application is consistent with the scheduler.
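Since _validate now receives the resolved run config, any scheduler subclass overriding this private hook has to accept the extra cfg argument. A rough sketch of such an override (MyScheduler and its opts type are hypothetical; only the changed hook is shown):

from torchx.schedulers.api import Scheduler
from torchx.specs import AppDef


class MyScheduler(Scheduler[dict]):
    def _validate(self, app: AppDef, scheduler: str, cfg: dict) -> None:
        # cfg is the resolved run config for this submission
        if len(app.roles) > 1:
            raise ValueError(f"{scheduler} only supports a single role")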
torchx/schedulers/aws_batch_scheduler.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 
 This contains the TorchX AWS Batch scheduler which can be used to run TorchX
@@ -807,6 +809,8 @@ class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
                 startFromHead=True,
                 **args,
             )
+        # pyre-fixme[66]: Exception handler type annotation `unknown` must
+        #  extend BaseException.
         except self._log_client.exceptions.ResourceNotFoundException:
             return []  # noqa: B901
         if response["nextForwardToken"] == next_token: