torchx-nightly 2024.2.11__py3-none-any.whl → 2025.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic; see the registry advisory for more details.

Files changed (102)
  1. torchx/__init__.py +2 -0
  2. torchx/apps/serve/serve.py +2 -0
  3. torchx/apps/utils/booth_main.py +2 -0
  4. torchx/apps/utils/copy_main.py +2 -0
  5. torchx/apps/utils/process_monitor.py +2 -0
  6. torchx/cli/__init__.py +2 -0
  7. torchx/cli/argparse_util.py +38 -3
  8. torchx/cli/cmd_base.py +2 -0
  9. torchx/cli/cmd_cancel.py +2 -0
  10. torchx/cli/cmd_configure.py +2 -0
  11. torchx/cli/cmd_describe.py +2 -0
  12. torchx/cli/cmd_list.py +2 -0
  13. torchx/cli/cmd_log.py +6 -24
  14. torchx/cli/cmd_run.py +30 -12
  15. torchx/cli/cmd_runopts.py +2 -0
  16. torchx/cli/cmd_status.py +2 -0
  17. torchx/cli/cmd_tracker.py +2 -0
  18. torchx/cli/colors.py +2 -0
  19. torchx/cli/main.py +2 -0
  20. torchx/components/__init__.py +2 -0
  21. torchx/components/component_test_base.py +2 -0
  22. torchx/components/dist.py +2 -0
  23. torchx/components/integration_tests/component_provider.py +2 -0
  24. torchx/components/integration_tests/integ_tests.py +2 -0
  25. torchx/components/serve.py +2 -0
  26. torchx/components/structured_arg.py +2 -0
  27. torchx/components/utils.py +2 -0
  28. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  29. torchx/examples/apps/lightning/data.py +5 -3
  30. torchx/examples/apps/lightning/model.py +2 -0
  31. torchx/examples/apps/lightning/profiler.py +7 -4
  32. torchx/examples/apps/lightning/train.py +2 -0
  33. torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
  34. torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
  35. torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/pipelines/kfp/__init__.py +2 -0
  39. torchx/pipelines/kfp/adapter.py +7 -4
  40. torchx/pipelines/kfp/version.py +2 -0
  41. torchx/runner/__init__.py +2 -0
  42. torchx/runner/api.py +78 -20
  43. torchx/runner/config.py +34 -3
  44. torchx/runner/events/__init__.py +37 -3
  45. torchx/runner/events/api.py +13 -2
  46. torchx/runner/events/handlers.py +2 -0
  47. torchx/runtime/tracking/__init__.py +2 -0
  48. torchx/runtime/tracking/api.py +2 -0
  49. torchx/schedulers/__init__.py +10 -5
  50. torchx/schedulers/api.py +3 -1
  51. torchx/schedulers/aws_batch_scheduler.py +4 -0
  52. torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
  53. torchx/schedulers/devices.py +17 -4
  54. torchx/schedulers/docker_scheduler.py +38 -8
  55. torchx/schedulers/gcp_batch_scheduler.py +8 -9
  56. torchx/schedulers/ids.py +2 -0
  57. torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
  58. torchx/schedulers/kubernetes_scheduler.py +31 -5
  59. torchx/schedulers/local_scheduler.py +45 -6
  60. torchx/schedulers/lsf_scheduler.py +3 -1
  61. torchx/schedulers/ray/ray_driver.py +7 -7
  62. torchx/schedulers/ray_scheduler.py +1 -1
  63. torchx/schedulers/slurm_scheduler.py +3 -1
  64. torchx/schedulers/streams.py +2 -0
  65. torchx/specs/__init__.py +49 -8
  66. torchx/specs/api.py +87 -5
  67. torchx/specs/builders.py +61 -19
  68. torchx/specs/file_linter.py +8 -2
  69. torchx/specs/finder.py +2 -0
  70. torchx/specs/named_resources_aws.py +109 -2
  71. torchx/specs/named_resources_generic.py +2 -0
  72. torchx/specs/test/components/__init__.py +2 -0
  73. torchx/specs/test/components/a/__init__.py +2 -0
  74. torchx/specs/test/components/a/b/__init__.py +2 -0
  75. torchx/specs/test/components/a/b/c.py +2 -0
  76. torchx/specs/test/components/c/__init__.py +2 -0
  77. torchx/specs/test/components/c/d.py +2 -0
  78. torchx/tracker/__init__.py +2 -0
  79. torchx/tracker/api.py +4 -4
  80. torchx/tracker/backend/fsspec.py +2 -0
  81. torchx/util/cuda.py +2 -0
  82. torchx/util/datetime.py +2 -0
  83. torchx/util/entrypoints.py +6 -2
  84. torchx/util/io.py +2 -0
  85. torchx/util/log_tee_helpers.py +210 -0
  86. torchx/util/modules.py +2 -0
  87. torchx/util/session.py +42 -0
  88. torchx/util/shlex.py +2 -0
  89. torchx/util/strings.py +2 -0
  90. torchx/util/types.py +20 -2
  91. torchx/version.py +3 -1
  92. torchx/workspace/__init__.py +2 -0
  93. torchx/workspace/api.py +34 -1
  94. torchx/workspace/dir_workspace.py +2 -0
  95. torchx/workspace/docker_workspace.py +25 -2
  96. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/METADATA +55 -48
  97. torchx_nightly-2025.1.14.dist-info/RECORD +123 -0
  98. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/WHEEL +1 -1
  99. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/entry_points.txt +0 -1
  100. torchx_nightly-2024.2.11.dist-info/RECORD +0 -119
  101. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/LICENSE +0 -0
  102. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/top_level.txt +0 -0
@@ -4,9 +4,12 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import fnmatch
8
10
  import logging
9
11
  import os.path
12
+ import re
10
13
  import tempfile
11
14
  from dataclasses import dataclass
12
15
  from datetime import datetime
@@ -121,6 +124,8 @@ def ensure_network(client: Optional["DockerClient"] = None) -> None:
121
124
 
122
125
  class DockerOpts(TypedDict, total=False):
123
126
  copy_env: Optional[List[str]]
127
+ env: Optional[Dict[str, str]]
128
+ privileged: bool
124
129
 
125
130
 
126
131
  class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
@@ -215,9 +220,14 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
215
220
  for k in keys:
216
221
  default_env[k] = os.environ[k]
217
222
 
223
+ env = cfg.get("env")
224
+ if env:
225
+ default_env.update(env)
226
+
218
227
  app_id = make_unique(app.name)
219
228
  req = DockerJob(app_id=app_id, containers=[])
220
- rank0_name = f"{app_id}-{app.roles[0].name}-0"
229
+ # trim app_id and role name in case name is longer than 64 letters
230
+ rank0_name = f"{app_id[-30:]}-{app.roles[0].name[:30]}-0"
221
231
  for role in app.roles:
222
232
  mounts = []
223
233
  devices = []
@@ -256,8 +266,12 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
256
266
  rank0_env="TORCHX_RANK0_HOST",
257
267
  )
258
268
  replica_role = values.apply(role)
259
- name = f"{app_id}-{role.name}-{replica_id}"
260
-
269
+ # trim app_id and role name in case name is longer than 64 letters. Assume replica_id is less than 10_000. Remove invalid prefixes (https://github.com/moby/moby/blob/master/daemon/names/names.go#L6).
270
+ name = re.sub(
271
+ r"^[^a-zA-Z0-9]+",
272
+ "",
273
+ f"{app_id[-30:]}-{role.name[:30]}-{replica_id}",
274
+ )
261
275
  env = default_env.copy()
262
276
  if replica_role.env:
263
277
  env.update(replica_role.env)
@@ -278,6 +292,7 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
278
292
  LABEL_REPLICA_ID: str(replica_id),
279
293
  },
280
294
  "hostname": name,
295
+ "privileged": cfg.get("privileged", False),
281
296
  "network": NETWORK,
282
297
  "mounts": mounts,
283
298
  "devices": devices,
@@ -292,9 +307,9 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
292
307
  if resource.memMB >= 0:
293
308
  # To support PyTorch dataloaders we need to set /dev/shm to
294
309
  # larger than the 64M default.
295
- c.kwargs["mem_limit"] = c.kwargs[
296
- "shm_size"
297
- ] = f"{int(resource.memMB)}m"
310
+ c.kwargs["mem_limit"] = c.kwargs["shm_size"] = (
311
+ f"{int(resource.memMB)}m"
312
+ )
298
313
  if resource.cpu >= 0:
299
314
  c.kwargs["nano_cpus"] = int(resource.cpu * 1e9)
300
315
  if resource.gpu > 0:
@@ -305,14 +320,14 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
305
320
  c.kwargs["device_requests"] = [
306
321
  DeviceRequest(
307
322
  count=resource.gpu,
308
- capabilities=[["compute"]],
323
+ capabilities=[["compute", "utility"]],
309
324
  )
310
325
  ]
311
326
  req.containers.append(c)
312
327
 
313
328
  return AppDryRunInfo(req, repr)
314
329
 
315
- def _validate(self, app: AppDef, scheduler: str) -> None:
330
+ def _validate(self, app: AppDef, scheduler: str, cfg: DockerOpts) -> None:
316
331
  # Skip validation step
317
332
  pass
318
333
 
@@ -357,6 +372,21 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
357
372
  default=None,
358
373
  help="list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*",
359
374
  )
375
+ opts.add(
376
+ "env",
377
+ type_=Dict[str, str],
378
+ default=None,
379
+ help="""environment variables to be passed to the run. The separator sign can be eiher comma or semicolon
380
+ (e.g. ENV1:v1,ENV2:v2,ENV3:v3 or ENV1:V1;ENV2:V2). Environment variables from env will be applied on top
381
+ of the ones from copy_env""",
382
+ )
383
+ opts.add(
384
+ "privileged",
385
+ type_=bool,
386
+ default=False,
387
+ help="If true runs the container with elevated permissions."
388
+ " Equivalent to running with `docker run --privileged`.",
389
+ )
360
390
  return opts
361
391
 
362
392
  def _get_app_state(self, container: "Container") -> AppState:
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
 
10
12
  This contains the TorchX GCP Batch scheduler which can be used to run TorchX
@@ -205,14 +207,12 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
205
207
  if cpu <= 0:
206
208
  cpu = 1
207
209
  MILLI = 1000
208
- # pyre-fixme[8]: Attribute has type `Field`; used as `int`.
209
210
  res.cpu_milli = cpu * MILLI
210
211
  memMB = resource.memMB
211
212
  if memMB < 0:
212
213
  raise ValueError(
213
214
  f"memMB should to be set to a positive value, got {memMB}"
214
215
  )
215
- # pyre-fixme[8]: Attribute has type `Field`; used as `int`.
216
216
  res.memory_mib = memMB
217
217
 
218
218
  # TODO support named resources
@@ -360,13 +360,11 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
360
360
  return None
361
361
 
362
362
  gpu = 0
363
- # pyre-fixme[16]: `Field` has no attribute `instances`.
364
363
  if len(job.allocation_policy.instances) != 0:
365
364
  gpu_type = job.allocation_policy.instances[0].policy.machine_type
366
365
  gpu = GPU_TYPE_TO_COUNT[gpu_type]
367
366
 
368
367
  roles = {}
369
- # pyre-fixme[16]: `RepeatedField` has no attribute `__iter__`.
370
368
  for tg in job.task_groups:
371
369
  env = tg.task_spec.environment.variables
372
370
  role = env["TORCHX_ROLE_NAME"]
@@ -390,7 +388,6 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
390
388
  # TODO map role/replica status
391
389
  desc = DescribeAppResponse(
392
390
  app_id=app_id,
393
- # pyre-fixme[16]: `Field` has no attribute `state`.
394
391
  state=JOB_STATE[job.status.state.name],
395
392
  roles=list(roles.values()),
396
393
  )
@@ -415,8 +412,10 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
415
412
  raise ValueError(f"app not found: {app_id}")
416
413
 
417
414
  job_uid = job.uid
418
- filters = [f"labels.job_uid={job_uid}"]
419
- filters.append(f"resource.labels.task_id:task/{job_uid}-group0-{k}")
415
+ filters = [
416
+ f"labels.job_uid={job_uid}",
417
+ f"labels.task_id:{job_uid}-group0-{k}",
418
+ ]
420
419
 
421
420
  if since is not None:
422
421
  filters.append(f'timestamp>="{str(since.isoformat())}"')
@@ -437,7 +436,7 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
437
436
 
438
437
  logger = logging.Client().logger(BATCH_LOGGER_NAME)
439
438
  for entry in logger.list_entries(filter_=filter):
440
- yield entry.payload
439
+ yield entry.payload + "\n"
441
440
 
442
441
  def _job_full_name_to_app_id(self, job_full_name: str) -> str:
443
442
  """
@@ -465,7 +464,7 @@ class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
465
464
  for job in all_jobs
466
465
  ]
467
466
 
468
- def _validate(self, app: AppDef, scheduler: str) -> None:
467
+ def _validate(self, app: AppDef, scheduler: str, cfg: GCPBatchOpts) -> None:
469
468
  # Skip validation step
470
469
  pass
471
470
 
torchx/schedulers/ids.py CHANGED
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  import os
9
11
  import random
10
12
  import struct
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
 
10
12
  This contains the TorchX Kubernetes_MCAD scheduler which can be used to run TorchX
@@ -1031,7 +1033,7 @@ class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts
1031
1033
  info._cfg = cfg
1032
1034
  return info
1033
1035
 
1034
- def _validate(self, app: AppDef, scheduler: str) -> None:
1036
+ def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesMCADOpts) -> None:
1035
1037
  # Skip validation step
1036
1038
  pass
1037
1039
 
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
 
10
12
  This contains the TorchX Kubernetes scheduler which can be used to run TorchX
@@ -23,7 +25,7 @@ Install Volcano:
23
25
  kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.6.0/installer/volcano-development.yaml
24
26
 
25
27
  See the
26
- `Volcano Quickstart <https://github.com/volcano-sh/volcano#quick-start-guide>`_
28
+ `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
27
29
  for more information.
28
30
  """
29
31
 
@@ -167,6 +169,17 @@ ANNOTATION_ISTIO_SIDECAR = "sidecar.istio.io/inject"
167
169
 
168
170
  LABEL_INSTANCE_TYPE = "node.kubernetes.io/instance-type"
169
171
 
172
+ # role.env translates to static env variables in the yaml
173
+ # {"FOO" : "bar"} =====> - name: FOO
174
+ # value: bar
175
+ # unless this placeholder is present at the start of the role.env value then the env variable
176
+ # in the yaml will be dynamically populated at runtime (placeholder is stripped out of the value)
177
+ # {"FOO" : "[FIELD_PATH]bar"} =====> - name: FOO
178
+ # valueFrom:
179
+ # fieldRef:
180
+ # fieldPath: bar
181
+ PLACEHOLDER_FIELD_PATH = "[FIELD_PATH]"
182
+
170
183
 
171
184
  def sanitize_for_serialization(obj: object) -> object:
172
185
  from kubernetes import client
@@ -181,7 +194,9 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
181
194
  V1ContainerPort,
182
195
  V1EmptyDirVolumeSource,
183
196
  V1EnvVar,
197
+ V1EnvVarSource,
184
198
  V1HostPathVolumeSource,
199
+ V1ObjectFieldSelector,
185
200
  V1ObjectMeta,
186
201
  V1PersistentVolumeClaimVolumeSource,
187
202
  V1Pod,
@@ -301,9 +316,20 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
301
316
  image=role.image,
302
317
  name=name,
303
318
  env=[
304
- V1EnvVar(
305
- name=name,
306
- value=value,
319
+ (
320
+ V1EnvVar(
321
+ name=name,
322
+ value_from=V1EnvVarSource(
323
+ field_ref=V1ObjectFieldSelector(
324
+ field_path=value.strip(PLACEHOLDER_FIELD_PATH)
325
+ )
326
+ ),
327
+ )
328
+ if value.startswith(PLACEHOLDER_FIELD_PATH)
329
+ else V1EnvVar(
330
+ name=name,
331
+ value=value,
332
+ )
307
333
  )
308
334
  for name, value in role.env.items()
309
335
  ],
@@ -635,7 +661,7 @@ class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
635
661
  )
636
662
  return AppDryRunInfo(req, repr)
637
663
 
638
- def _validate(self, app: AppDef, scheduler: str) -> None:
664
+ def _validate(self, app: AppDef, scheduler: str, cfg: KubernetesOpts) -> None:
639
665
  # Skip validation step
640
666
  pass
641
667
 
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
  This contains the TorchX local scheduler which can be used to run TorchX
10
12
  components locally via subprocesses.
@@ -35,6 +37,7 @@ from typing import (
35
37
  Iterable,
36
38
  List,
37
39
  Optional,
40
+ Protocol,
38
41
  TextIO,
39
42
  Tuple,
40
43
  )
@@ -262,6 +265,26 @@ AppName = str
262
265
  RoleName = str
263
266
 
264
267
 
268
+ class PopenProtocol(Protocol):
269
+ """
270
+ Protocol wrapper around python's ``subprocess.Popen``. Keeps track of
271
+ the a list of interface methods that the process scheduled by the `LocalScheduler`
272
+ must implement.
273
+ """
274
+
275
+ @property
276
+ def pid(self) -> int: ...
277
+
278
+ @property
279
+ def returncode(self) -> int: ...
280
+
281
+ def wait(self, timeout: Optional[float] = None) -> int: ...
282
+
283
+ def poll(self) -> Optional[int]: ...
284
+
285
+ def kill(self) -> None: ...
286
+
287
+
265
288
  @dataclass
266
289
  class _LocalReplica:
267
290
  """
@@ -270,8 +293,7 @@ class _LocalReplica:
270
293
 
271
294
  role_name: RoleName
272
295
  replica_id: int
273
- # pyre-fixme[24]: Generic type `subprocess.Popen` expects 1 type parameter.
274
- proc: subprocess.Popen
296
+ proc: PopenProtocol
275
297
 
276
298
  # IO streams:
277
299
  # None means no log_dir (out to console)
@@ -608,7 +630,7 @@ class LocalScheduler(Scheduler[LocalOpts]):
608
630
  )
609
631
  return opts
610
632
 
611
- def _validate(self, app: AppDef, scheduler: str) -> None:
633
+ def _validate(self, app: AppDef, scheduler: str, cfg: LocalOpts) -> None:
612
634
  # Skip validation step for local application
613
635
  pass
614
636
 
@@ -674,12 +696,11 @@ class LocalScheduler(Scheduler[LocalOpts]):
674
696
  log.debug(f"Running {role_name} (replica {replica_id}):\n {args_pfmt}")
675
697
  env = self._get_replica_env(replica_params)
676
698
 
677
- proc = subprocess.Popen(
699
+ proc = self.run_local_job(
678
700
  args=replica_params.args,
679
701
  env=env,
680
702
  stdout=stdout_,
681
703
  stderr=stderr_,
682
- start_new_session=True,
683
704
  cwd=replica_params.cwd,
684
705
  )
685
706
  return _LocalReplica(
@@ -692,6 +713,23 @@ class LocalScheduler(Scheduler[LocalOpts]):
692
713
  error_file=env.get("TORCHELASTIC_ERROR_FILE", "<N/A>"),
693
714
  )
694
715
 
716
+ def run_local_job(
717
+ self,
718
+ args: List[str],
719
+ env: Dict[str, str],
720
+ stdout: Optional[io.FileIO],
721
+ stderr: Optional[io.FileIO],
722
+ cwd: Optional[str] = None,
723
+ ) -> "subprocess.Popen[bytes]":
724
+ return subprocess.Popen(
725
+ args=args,
726
+ env=env,
727
+ stdout=stdout,
728
+ stderr=stderr,
729
+ start_new_session=True,
730
+ cwd=cwd,
731
+ )
732
+
695
733
  def _get_replica_output_handles(
696
734
  self,
697
735
  replica_params: ReplicaParam,
@@ -1162,11 +1200,12 @@ def create_scheduler(
1162
1200
  session_name: str,
1163
1201
  cache_size: int = 100,
1164
1202
  extra_paths: Optional[List[str]] = None,
1203
+ image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider,
1165
1204
  **kwargs: Any,
1166
1205
  ) -> LocalScheduler:
1167
1206
  return LocalScheduler(
1168
1207
  session_name=session_name,
1169
- image_provider_class=CWDImageProvider,
1208
+ image_provider_class=image_provider_class,
1170
1209
  cache_size=cache_size,
1171
1210
  extra_paths=extra_paths,
1172
1211
  )
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
  This contains the TorchX LSF scheduler which can be used to run TorchX
10
12
  components on a LSF cluster.
@@ -486,7 +488,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
486
488
  subprocess.run(req.cmd, stdout=subprocess.PIPE, check=True)
487
489
  return req.app_id
488
490
 
489
- def _validate(self, app: AppDef, scheduler: str) -> None:
491
+ def _validate(self, app: AppDef, scheduler: str, cfg: LsfOpts) -> None:
490
492
  # Skip validation step for lsf
491
493
  pass
492
494
 
@@ -116,7 +116,7 @@ def load_actor_json(filename: str) -> List[RayActor]:
116
116
  return actors
117
117
 
118
118
 
119
- def create_placement_group_async(replicas: List[RayActor]) -> PlacementGroup:
119
+ def create_placement_group_async(replicas: List[RayActor]) -> PlacementGroup: # type: ignore
120
120
  """return a placement group reference, the corresponding placement group could be scheduled or pending"""
121
121
  bundles = []
122
122
  for replica in replicas:
@@ -148,12 +148,12 @@ class RayDriver:
148
148
  else:
149
149
  self.min_replicas = replicas[0].min_replicas # pyre-ignore[8]
150
150
 
151
- self.placement_groups: List[
152
- PlacementGroup
153
- ] = [] # all the placement groups, shall never change
154
- self.actor_info_of_id: Dict[
155
- str, ActorInfo
156
- ] = {} # store the info used to recover an actor
151
+ self.placement_groups: List[PlacementGroup] = (
152
+ []
153
+ ) # all the placement groups, shall never change
154
+ self.actor_info_of_id: Dict[str, ActorInfo] = (
155
+ {}
156
+ ) # store the info used to recover an actor
157
157
  self.active_tasks: List["ray.ObjectRef"] = [] # list of active tasks
158
158
 
159
159
  self.terminating: bool = False # if the job has finished and being terminated
@@ -318,7 +318,7 @@ if _has_ray:
318
318
 
319
319
  return AppDryRunInfo(job, repr)
320
320
 
321
- def _validate(self, app: AppDef, scheduler: str) -> None:
321
+ def _validate(self, app: AppDef, scheduler: str, cfg: RayOpts) -> None:
322
322
  if scheduler != "ray":
323
323
  raise ValueError(
324
324
  f"An unknown scheduler backend '{scheduler}' has been passed to the Ray scheduler."
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
  This contains the TorchX Slurm scheduler which can be used to run TorchX
10
12
  components on a Slurm cluster.
@@ -470,7 +472,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
470
472
 
471
473
  return AppDryRunInfo(req, repr)
472
474
 
473
- def _validate(self, app: AppDef, scheduler: str) -> None:
475
+ def _validate(self, app: AppDef, scheduler: str, cfg: SlurmOpts) -> None:
474
476
  # Skip validation step for slurm
475
477
  pass
476
478
 
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  import io
9
11
  import os
10
12
  import threading
torchx/specs/__init__.py CHANGED
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  """
9
11
  This contains the TorchX AppDef and related component definitions. These are
10
12
  used by components to define the apps which can then be launched via a TorchX
@@ -13,13 +15,7 @@ scheduler or pipeline adapter.
13
15
  import difflib
14
16
  from typing import Callable, Dict, Optional
15
17
 
16
- from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES
17
- from torchx.specs.named_resources_generic import (
18
- NAMED_RESOURCES as GENERIC_NAMED_RESOURCES,
19
- )
20
- from torchx.util.entrypoints import load_group
21
-
22
- from .api import ( # noqa: F401 F403
18
+ from torchx.specs.api import (
23
19
  ALL,
24
20
  AppDef,
25
21
  AppDryRunInfo,
@@ -50,7 +46,13 @@ from .api import ( # noqa: F401 F403
50
46
  UnknownSchedulerException,
51
47
  VolumeMount,
52
48
  )
53
- from .builders import make_app_handle, materialize_appdef, parse_mounts # noqa
49
+ from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts
50
+
51
+ from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES
52
+ from torchx.specs.named_resources_generic import (
53
+ NAMED_RESOURCES as GENERIC_NAMED_RESOURCES,
54
+ )
55
+ from torchx.util.entrypoints import load_group
54
56
 
55
57
  GiB: int = 1024
56
58
 
@@ -181,3 +183,42 @@ def get_named_resources(res: str) -> Resource:
181
183
 
182
184
  """
183
185
  return named_resources[res]
186
+
187
+
188
+ __all__ = [
189
+ "AppDef",
190
+ "AppDryRunInfo",
191
+ "AppHandle",
192
+ "AppState",
193
+ "AppStatus",
194
+ "BindMount",
195
+ "CfgVal",
196
+ "DeviceMount",
197
+ "get_type_name",
198
+ "is_terminal",
199
+ "macros",
200
+ "MISSING",
201
+ "NONE",
202
+ "NULL_RESOURCE",
203
+ "parse_app_handle",
204
+ "ReplicaState",
205
+ "ReplicaStatus",
206
+ "Resource",
207
+ "RetryPolicy",
208
+ "Role",
209
+ "RoleStatus",
210
+ "runopt",
211
+ "runopts",
212
+ "UnknownAppException",
213
+ "UnknownSchedulerException",
214
+ "InvalidRunConfigException",
215
+ "MalformedAppHandleException",
216
+ "VolumeMount",
217
+ "resource",
218
+ "get_named_resources",
219
+ "named_resources",
220
+ "make_app_handle",
221
+ "materialize_appdef",
222
+ "parse_mounts",
223
+ "ALL",
224
+ ]