torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly has been flagged as possibly problematic.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +7 -6
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +79 -5
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +431 -32
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +254 -22
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
  108. torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
  109. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/components/dist.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 For distributed training, TorchX relies on the scheduler's gang scheduling
 capabilities to schedule ``n`` copies of nodes. Once launched, the application
@@ -90,6 +92,7 @@ def spmd(
     h: str = "gpu.small",
     j: str = "1x1",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
@@ -129,10 +132,8 @@ def spmd(
         h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
         j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
         env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
-        rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
-            Only takes effect when running multi-node. When running single node, this parameter
-            is ignored and a random free port is chosen.
         mounts: (for docker based runs only) mounts to mount into the worker environment/container
             (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
         debug: whether to run with preset debug flags enabled
@@ -151,6 +152,7 @@ def spmd(
         h=h,
         j=str(StructuredJArgument.parse_from(h, j)),
         env=env,
+        metadata=metadata,
         max_retries=max_retries,
         mounts=mounts,
         debug=debug,
@@ -169,11 +171,14 @@ def ddp(
     memMB: int = 1024,
     j: str = "1x2",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     rdzv_port: int = 29500,
     rdzv_backend: str = "c10d",
+    rdzv_conf: Optional[str] = None,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
+    tee: int = 3,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -185,7 +190,7 @@ def ddp(
 
     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
 
     Args:
         script_args: arguments to the main module
@@ -200,14 +205,17 @@ def ddp(
         h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
         j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
         env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
             Only takes effect when running multi-node. When running single node, this parameter
            is ignored and a random free port is chosen.
         rdzv_backend: the rendezvous backend to use. Only takes effect when running multi-node.
+        rdzv_conf: the additional rendezvous configuration to use (ex. join_timeout=600,close_timeout=600,timeout=600).
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
         debug: whether to run with preset debug flags enabled
+        tee: tees the specified std stream(s) to console + file. 0: none, 1: stdout, 2: stderr, 3: both
     """
 
     if (script is None) == (m is None):
@@ -234,8 +242,8 @@ def ddp(
     # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
     rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")
 
-    if env is None:
-        env = {}
+    env = env or {}
+    metadata = metadata or {}
 
     argname = StructuredNameArgument.parse_from(
         name=name,
@@ -244,6 +252,7 @@ def ddp(
     )
 
     env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
+    env["TORCHX_TRACKING_RUN_NAME"] = argname.run_name
 
     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
     if debug:
@@ -253,6 +262,7 @@ def ddp(
         "torchrun",
         "--rdzv_backend",
         rdzv_backend,
+        *(["--rdzv_conf", rdzv_conf] if rdzv_conf is not None else []),
         "--rdzv_endpoint",
         rdzv_endpoint,
         "--rdzv_id",
@@ -262,7 +272,7 @@ def ddp(
         "--nproc_per_node",
         str(nproc_per_node),
         "--tee",
-        "3",
+        str(tee),
         "--role",
         "",
     ]
@@ -294,6 +304,7 @@ def ddp(
                 mounts=specs.parse_mounts(mounts) if mounts else [],
             )
         ],
+        metadata=metadata,
     )
 
 
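
The changes above add three knobs to dist.ddp (and metadata to spmd). A minimal sketch of how they might be used when building an AppDef programmatically — the parameter names come from this diff, all values are illustrative:

    from torchx.components import dist

    app = dist.ddp(
        "--epochs", "10",               # forwarded to the training script
        script="train.py",
        j="2x8",                        # 2 nodes x 8 processes per node
        rdzv_backend="c10d",
        rdzv_conf="join_timeout=600",   # new: extra rendezvous settings passed to torchrun
        tee=1,                          # new: 0=none, 1=stdout, 2=stderr, 3=both (previously hard-coded to 3)
        metadata={"team": "ml-infra"},  # new: forwarded to the scheduler via AppDef.metadata
    )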

torchx/components/integration_tests/component_provider.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import os
 import tempfile
 from abc import ABC, abstractmethod
@@ -107,7 +109,7 @@ class CopyComponentProvider(ComponentProvider):
         self._dst_path = "<None>"
 
     def setUp(self) -> None:
-        if self._scheduler in ["local_cwd", "ray"]:
+        if self._scheduler in ["local_cwd"]:
             fname = "torchx_copy_test.txt"
             self._src_path: str = os.path.join(tempfile.gettempdir(), fname)
             self._dst_path: str = os.path.join(tempfile.gettempdir(), f"{fname}.copy")
@@ -124,7 +126,7 @@ class CopyComponentProvider(ComponentProvider):
     def tearDown(self) -> None:
         if os.path.exists(self._dst_path):
             os.remove(self._dst_path)
-        if self._scheduler in ["local_cwd", "ray"] and os.path.exists(self._dst_path):
+        if self._scheduler in ["local_cwd"] and os.path.exists(self._dst_path):
             os.remove(self._dst_path)
 
     def get_app_def(self) -> AppDef:

torchx/components/integration_tests/integ_tests.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import inspect
 import logging
 import sys

torchx/components/serve.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 These components aim to make it easier to interact with inference and serving
 tools such as `torchserve <https://pytorch.org/serve/>`_.

torchx/components/structured_arg.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Defines methods for structured (higher order) component argument parsing.
 Use the functionalities defined in this module to author components
@@ -28,8 +30,6 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
 
-from pyre_extensions import none_throws
-
 from torchx import specs
 
 
@@ -146,7 +146,8 @@ class StructuredNameArgument:
         if m:  # use the last module name
             run_name = m.rpartition(".")[2]
         else:  # use script name w/ no extension
-            run_name = Path(none_throws(script)).stem
+            assert script, "`script` can't be `None` here due checks above"
+            run_name = Path(script).stem
         return StructuredNameArgument(
             experiment_name or default_experiment_name, run_name
         )
@@ -191,12 +192,12 @@ class StructuredJArgument:
 
     .. doctest::
 
-        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2"))
+        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2"))
         Traceback (most recent call last):
         ...
-        ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xl` is not a GPU instance. ...
+        ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xlarge` is not a GPU instance. ...
 
-        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2x16"))
+        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16"))
         '2x16'
 
     """

torchx/components/utils.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 This contains TorchX utility components that are `ready-to-use` out of the box. These are
 components that simply execute well known binaries (e.g. ``cp``)
@@ -81,6 +83,7 @@ def sh(
     env: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
+    entrypoint: Optional[str] = None,
 ) -> specs.AppDef:
     """
     Runs the provided command via sh. Currently sh does not support
@@ -98,21 +101,29 @@ def sh(
         max_retries: the number of scheduler retries allowed
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
+        entrypoint: the entrypoint to use for the command (defaults to sh)
     """
 
-    escaped_args = " ".join(shlex.quote(arg) for arg in args)
+    escaped_args = [shlex.quote(arg) for arg in args]
     if env is None:
         env = {}
     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
 
+    if entrypoint is not None:
+        resolved_entrypoint = entrypoint
+        resolved_args = escaped_args
+    else:
+        resolved_entrypoint = "sh"
+        resolved_args = ["-c", " ".join(escaped_args)]
+
     return specs.AppDef(
         name="sh",
         roles=[
             specs.Role(
                 name="sh",
                 image=image,
-                entrypoint="sh",
-                args=["-c", escaped_args],
+                entrypoint=resolved_entrypoint,
+                args=resolved_args,
                 num_replicas=num_replicas,
                 resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                 env=env,
@@ -143,7 +154,7 @@ def python(
 
     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
 
     Args:
         args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
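
The new entrypoint argument changes how utils.sh builds its command: by default the quoted args are still joined into a single `sh -c "<cmd>"` string, while an explicit entrypoint receives the quoted args directly. A small sketch with illustrative values only:

    from torchx.components import utils

    # unchanged default: runs `sh -c "echo 'hello world'"`
    app_sh = utils.sh("echo", "hello world")

    # new path in this diff: args go straight to the given entrypoint
    app_py = utils.sh("-c", "print('hi')", entrypoint="python")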

torchx/distributed/__init__.py CHANGED
@@ -48,7 +48,7 @@ def local_rank() -> int:
         " but the `LOCAL_RANK` environment variable is not set. Will trivially return 0 for local_rank.\n"
         " It is recommended to use torchrun/torchx to run your script or set the `LOCAL_RANK` manually.\n"
         " For additional details see:\n"
-        " 1) https://pytorch.org/torchx/latest/components/distributed.html\n"
+        " 1) https://meta-pytorch.org/torchx/latest/components/distributed.html\n"
         " 2) https://pytorch.org/docs/stable/elastic/run.html\n"
         "=============================================================================================="
     )
@@ -83,9 +83,7 @@ def local_device() -> torch.device:
     if dist.is_initialized():
         default_pg = _get_default_group()
         return (
-            local_cuda_device()
-            if default_pg.options.backend == "nccl"
-            else torch.device("cpu")
+            local_cuda_device() if default_pg.name() == "nccl" else torch.device("cpu")
         )
     else:
         return torch.device("cuda") if has_cuda_devices() else torch.device("cpu")

torchx/examples/apps/datapreproc/datapreproc.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Data Preprocessing App Example
 ====================================

torchx/examples/apps/lightning/data.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Trainer Datasets Example
 ========================
@@ -62,17 +64,17 @@ class ImageFolderSamplesDataset(datasets.ImageFolder):
 # our trainer and other components that need to load data.
 
 
-# pyre-fixme[13]: Attribute `test_ds` is never initialized.
-# pyre-fixme[13]: Attribute `train_ds` is never initialized.
-# pyre-fixme[13]: Attribute `val_ds` is never initialized.
 class TinyImageNetDataModule(pl.LightningDataModule):
     """
     TinyImageNetDataModule is a pytorch LightningDataModule for the tiny
     imagenet dataset.
     """
 
+    # pyre-fixme[13]: Attribute `test_ds` is never initialized.
     train_ds: ImageFolderSamplesDataset
+    # pyre-fixme[13]: Attribute `train_ds` is never initialized.
     val_ds: ImageFolderSamplesDataset
+    # pyre-fixme[13]: Attribute `val_ds` is never initialized.
     test_ds: ImageFolderSamplesDataset
 
     def __init__(

torchx/examples/apps/lightning/model.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Tiny ImageNet Model
 ====================
@@ -21,7 +23,7 @@ import pytorch_lightning as pl
 import torch
 import torch.jit
 from torch.nn import functional as F
-from torchmetrics import Accuracy
+from torchmetrics.classification import MulticlassAccuracy
 from torchvision.models.resnet import BasicBlock, ResNet
 
 
@@ -42,13 +44,12 @@ class TinyImageNetModel(pl.LightningModule):
 
         # We use the torchvision resnet model with some small tweaks to match
         # TinyImageNet.
-        m = ResNet(BasicBlock, layer_sizes)
+        m = ResNet(BasicBlock, layer_sizes, num_classes=200)
         m.avgpool = torch.nn.AdaptiveAvgPool2d(1)
-        m.fc.out_features = 200
         self.model: ResNet = m
 
-        self.train_acc = Accuracy()
-        self.val_acc = Accuracy()
+        self.train_acc = MulticlassAccuracy(num_classes=m.fc.out_features)
+        self.val_acc = MulticlassAccuracy(num_classes=m.fc.out_features)
 
     # pyre-fixme[14]
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -69,7 +70,7 @@ class TinyImageNetModel(pl.LightningModule):
     def _step(
         self,
         step_name: str,
-        acc_metric: Accuracy,
+        acc_metric: MulticlassAccuracy,
         batch: Tuple[torch.Tensor, torch.Tensor],
         batch_idx: int,
     ) -> torch.Tensor:
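
The model change tracks the torchmetrics API: the task-agnostic Accuracy() constructor is replaced by MulticlassAccuracy, which takes num_classes up front, and the class count now goes through ResNet(..., num_classes=200) instead of mutating fc.out_features after construction (that assignment did not actually resize the layer). A standalone sketch of the new metric, with illustrative tensors:

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    acc = MulticlassAccuracy(num_classes=200)
    logits = torch.randn(8, 200)            # model outputs for a batch of 8
    target = torch.randint(0, 200, (8,))    # ground-truth class ids
    print(acc(logits, target))              # per-batch accuracy as a 0-dim tensor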

torchx/examples/apps/lightning/profiler.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Simple Logging Profiler
 ===========================
@@ -17,18 +19,19 @@ output is used for HPO optimization with Ax.
 import time
 from typing import Dict
 
-from pytorch_lightning.loggers.base import LightningLoggerBase
-from pytorch_lightning.profiler.base import BaseProfiler
+from pytorch_lightning.loggers.logger import Logger
+
+from pytorch_lightning.profilers.profiler import Profiler
 
 
-class SimpleLoggingProfiler(BaseProfiler):
+class SimpleLoggingProfiler(Profiler):
     """
     This profiler records the duration of actions (in seconds) and reports the
     mean duration of each action to the specified logger. Reported metrics are
     in the format `duration_<event>`.
     """
 
-    def __init__(self, logger: LightningLoggerBase) -> None:
+    def __init__(self, logger: Logger) -> None:
         super().__init__()
 
         self.current_actions: Dict[str, float] = {}
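
The profiler edit only moves to the import paths used by newer PyTorch Lightning releases (loggers.logger.Logger, profilers.profiler.Profiler); its public usage is unchanged. A hedged wiring sketch, assuming the same logger is handed to both the Trainer and the profiler:

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger
    from torchx.examples.apps.lightning.profiler import SimpleLoggingProfiler

    logger = TensorBoardLogger(save_dir="/tmp/tb_logs")
    trainer = pl.Trainer(
        logger=logger,
        profiler=SimpleLoggingProfiler(logger),  # reports mean duration_<action> metrics
        max_epochs=1,
    )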

torchx/examples/apps/lightning/train.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Trainer Example
 =============================================
@@ -58,6 +60,7 @@ import pytorch_lightning as pl
 import torch
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
+from torch.distributed.elastic.multiprocessing import errors
 from torchx.examples.apps.lightning.data import (
     create_random_data,
     download_data,
@@ -83,7 +86,12 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
     parser.add_argument(
         "--batch_size", type=int, default=32, help="batch size to use for training"
     )
-    parser.add_argument("--num_samples", type=int, default=10, help="num_samples")
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=32,
+        help="number of samples in the dataset",
+    )
     parser.add_argument(
         "--data_path",
         type=str,
@@ -124,6 +132,7 @@ def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
     )
 
 
+@errors.record
 def main(argv: List[str]) -> None:
     with tempfile.TemporaryDirectory() as tmpdir:
         args = parse_args(argv)
@@ -136,7 +145,7 @@ def main(argv: List[str]) -> None:
         if not args.data_path:
             data_path = os.path.join(tmpdir, "data")
             os.makedirs(data_path)
-            create_random_data(data_path)
+            create_random_data(data_path, args.num_samples)
         else:
             data_path = download_data(args.data_path, tmpdir)
 
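
The trainer also picks up @errors.record from TorchElastic. A brief standalone sketch of what the decorator does (not tied to this example): it catches an uncaught exception in the decorated entrypoint, writes a structured error file for the elastic agent (path taken from TORCHELASTIC_ERROR_FILE when set), and re-raises, so torchrun can report the root-cause worker failure:

    from torch.distributed.elastic.multiprocessing.errors import record

    @record
    def main() -> None:
        raise RuntimeError("boom")  # captured into the error file, then re-raised

    if __name__ == "__main__":
        main()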

torchx/examples/torchx_out_of_sync_training.py ADDED
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# This file is for lex team to ramp up on torchX OSS. It's for training purpose.
+# Please DO NOT modify this file unless you know what you are doing.
torchx/notebook.py CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 This contains TorchX utilities for creating and running components and apps from
 an Jupyter/IPython Notebook.
torchx/runner/__init__.py CHANGED
@@ -5,4 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from torchx.runner.api import get_runner, Runner  # noqa: F401 F403