torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchx-nightly might be problematic.
- torchx/__init__.py +2 -0
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +8 -4
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +269 -45
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +12 -1
- torchx/cli/cmd_tracker.py +3 -1
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +4 -0
- torchx/components/__init__.py +3 -8
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +18 -7
- torchx/components/integration_tests/component_provider.py +4 -2
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +7 -6
- torchx/components/utils.py +15 -4
- torchx/distributed/__init__.py +2 -4
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +7 -6
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +11 -2
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +167 -60
- torchx/runner/config.py +43 -10
- torchx/runner/events/__init__.py +57 -13
- torchx/runner/events/api.py +14 -3
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +16 -15
- torchx/schedulers/api.py +70 -14
- torchx/schedulers/aws_batch_scheduler.py +79 -5
- torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +43 -11
- torchx/schedulers/ids.py +29 -23
- torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
- torchx/schedulers/kubernetes_scheduler.py +383 -38
- torchx/schedulers/local_scheduler.py +100 -27
- torchx/schedulers/lsf_scheduler.py +5 -4
- torchx/schedulers/slurm_scheduler.py +336 -20
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +89 -12
- torchx/specs/api.py +431 -32
- torchx/specs/builders.py +176 -38
- torchx/specs/file_linter.py +143 -57
- torchx/specs/finder.py +68 -28
- torchx/specs/named_resources_aws.py +254 -22
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/overlays.py +106 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +12 -6
- torchx/tracker/api.py +15 -18
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +39 -15
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +65 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +3 -1
- torchx/util/types.py +90 -29
- torchx/version.py +4 -2
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +136 -6
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +30 -2
- torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
- torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
- torchx/pipelines/kfp/__init__.py +0 -28
- torchx/pipelines/kfp/adapter.py +0 -271
- torchx/pipelines/kfp/version.py +0 -17
- torchx/schedulers/gcp_batch_scheduler.py +0 -487
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -453
- torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
- torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/components/dist.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 For distributed training, TorchX relies on the scheduler's gang scheduling
 capabilities to schedule ``n`` copies of nodes. Once launched, the application
@@ -90,6 +92,7 @@ def spmd(
     h: str = "gpu.small",
     j: str = "1x1",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
@@ -129,10 +132,8 @@ def spmd(
         h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
         j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
         env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
-        rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
-            Only takes effect when running multi-node. When running single node, this parameter
-            is ignored and a random free port is chosen.
         mounts: (for docker based runs only) mounts to mount into the worker environment/container
             (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
         debug: whether to run with preset debug flags enabled
@@ -151,6 +152,7 @@ def spmd(
         h=h,
         j=str(StructuredJArgument.parse_from(h, j)),
         env=env,
+        metadata=metadata,
         max_retries=max_retries,
         mounts=mounts,
         debug=debug,
@@ -169,11 +171,14 @@ def ddp(
     memMB: int = 1024,
     j: str = "1x2",
     env: Optional[Dict[str, str]] = None,
+    metadata: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     rdzv_port: int = 29500,
     rdzv_backend: str = "c10d",
+    rdzv_conf: Optional[str] = None,
     mounts: Optional[List[str]] = None,
     debug: bool = False,
+    tee: int = 3,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -185,7 +190,7 @@ def ddp(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         script_args: arguments to the main module
@@ -200,14 +205,17 @@ def ddp(
         h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
         j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
         env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
+        metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
         max_retries: the number of scheduler retries allowed
         rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
             Only takes effect when running multi-node. When running single node, this parameter
             is ignored and a random free port is chosen.
         rdzv_backend: the rendezvous backend to use. Only takes effect when running multi-node.
+        rdzv_conf: the additional rendezvous configuration to use (ex. join_timeout=600,close_timeout=600,timeout=600).
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
         debug: whether to run with preset debug flags enabled
+        tee: tees the specified std stream(s) to console + file. 0: none, 1: stdout, 2: stderr, 3: both
     """

     if (script is None) == (m is None):
@@ -234,8 +242,8 @@ def ddp(
     # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
     rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")

-
-
+    env = env or {}
+    metadata = metadata or {}

     argname = StructuredNameArgument.parse_from(
         name=name,
@@ -244,6 +252,7 @@ def ddp(
     )

     env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name
+    env["TORCHX_TRACKING_RUN_NAME"] = argname.run_name

     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
     if debug:
@@ -253,6 +262,7 @@ def ddp(
         "torchrun",
         "--rdzv_backend",
         rdzv_backend,
+        *(["--rdzv_conf", rdzv_conf] if rdzv_conf is not None else []),
         "--rdzv_endpoint",
         rdzv_endpoint,
         "--rdzv_id",
@@ -262,7 +272,7 @@ def ddp(
         "--nproc_per_node",
         str(nproc_per_node),
         "--tee",
-
+        str(tee),
         "--role",
         "",
     ]
@@ -294,6 +304,7 @@ def ddp(
                 mounts=specs.parse_mounts(mounts) if mounts else [],
             )
         ],
+        metadata=metadata,
     )
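The dist.py diff above adds three arguments to the ddp component (metadata, rdzv_conf, tee) and forwards metadata into the resulting AppDef. A minimal usage sketch based on the signature shown in this diff; the script name, node layout, and metadata values are illustrative, not part of the release:

from torchx.components.dist import ddp

app = ddp(
    "--epochs", "10",                # script_args forwarded to the training script
    script="train.py",               # hypothetical user script
    j="2x8",                         # 2 nodes x 8 processes per node
    rdzv_backend="c10d",
    rdzv_conf="join_timeout=600",    # new: extra rendezvous config forwarded to torchrun
    metadata={"team": "example"},    # new: passed through to the scheduler via AppDef.metadata
    tee=3,                           # new: tee stdout+stderr to console + file (torchrun --tee)
)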

torchx/components/integration_tests/component_provider.py
CHANGED

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 import os
 import tempfile
 from abc import ABC, abstractmethod
@@ -107,7 +109,7 @@ class CopyComponentProvider(ComponentProvider):
         self._dst_path = "<None>"

     def setUp(self) -> None:
-        if self._scheduler in ["local_cwd"
+        if self._scheduler in ["local_cwd"]:
             fname = "torchx_copy_test.txt"
             self._src_path: str = os.path.join(tempfile.gettempdir(), fname)
             self._dst_path: str = os.path.join(tempfile.gettempdir(), f"{fname}.copy")
@@ -124,7 +126,7 @@ class CopyComponentProvider(ComponentProvider):
     def tearDown(self) -> None:
         if os.path.exists(self._dst_path):
             os.remove(self._dst_path)
-        if self._scheduler in ["local_cwd"
+        if self._scheduler in ["local_cwd"] and os.path.exists(self._dst_path):
             os.remove(self._dst_path)

     def get_app_def(self) -> AppDef:
torchx/components/serve.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 These components aim to make it easier to interact with inference and serving
 tools such as `torchserve <https://pytorch.org/serve/>`_.

torchx/components/structured_arg.py
CHANGED

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 Defines methods for structured (higher order) component argument parsing.
 Use the functionalities defined in this module to author components
@@ -28,8 +30,6 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional

-from pyre_extensions import none_throws
-
 from torchx import specs


@@ -146,7 +146,8 @@ class StructuredNameArgument:
         if m:  # use the last module name
             run_name = m.rpartition(".")[2]
         else:  # use script name w/ no extension
-
+            assert script, "`script` can't be `None` here due checks above"
+            run_name = Path(script).stem
         return StructuredNameArgument(
             experiment_name or default_experiment_name, run_name
         )
@@ -191,12 +192,12 @@ class StructuredJArgument:

     .. doctest::

-        >>> str(StructuredJArgument.parse_from(h="aws_trn1.
+        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2"))
         Traceback (most recent call last):
         ...
-        ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.
+        ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xlarge` is not a GPU instance. ...

-        >>> str(StructuredJArgument.parse_from(h="aws_trn1.
+        >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16"))
         '2x16'

     """
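As context for the doctest fix above, a short sketch of how StructuredJArgument parses the j argument. The first call mirrors the doctest in this diff; the GPU-host example is an assumption based on the documented GPU-count inference:

from torchx.components.structured_arg import StructuredJArgument

str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16"))  # -> "2x16" (explicit nnodes x nproc_per_node)
str(StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2"))      # nproc_per_node inferred from the host's GPU count, e.g. "2x8"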
torchx/components/utils.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 This contains TorchX utility components that are `ready-to-use` out of the box. These are
 components that simply execute well known binaries (e.g. ``cp``)
@@ -81,6 +83,7 @@ def sh(
     env: Optional[Dict[str, str]] = None,
     max_retries: int = 0,
     mounts: Optional[List[str]] = None,
+    entrypoint: Optional[str] = None,
 ) -> specs.AppDef:
     """
     Runs the provided command via sh. Currently sh does not support
@@ -98,21 +101,29 @@ def sh(
         max_retries: the number of scheduler retries allowed
         mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
             See scheduler documentation for more info.
+        entrypoint: the entrypoint to use for the command (defaults to sh)
     """

-    escaped_args =
+    escaped_args = [shlex.quote(arg) for arg in args]
     if env is None:
         env = {}
     env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))

+    if entrypoint is not None:
+        resolved_entrypoint = entrypoint
+        resolved_args = escaped_args
+    else:
+        resolved_entrypoint = "sh"
+        resolved_args = ["-c", " ".join(escaped_args)]
+
     return specs.AppDef(
         name="sh",
         roles=[
             specs.Role(
                 name="sh",
                 image=image,
-                entrypoint=
-                args=
+                entrypoint=resolved_entrypoint,
+                args=resolved_args,
                 num_replicas=num_replicas,
                 resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                 env=env,
@@ -143,7 +154,7 @@ def python(

     Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
     ``h`` takes precedence if specified for setting resource requirements.
-    See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
+    See `registering named resources <https://meta-pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.

     Args:
         args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
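The sh component above gains an optional entrypoint override. A small sketch of how it might be called, assuming the signature in this diff; the image and commands are placeholders:

from torchx.components.utils import sh

# default behaviour: the arguments are shell-quoted and run via `sh -c "<args>"`
echo_app = sh("echo", "hello world", image="example/image:latest")

# new: run the quoted arguments directly under a custom entrypoint instead of sh -c
env_app = sh("-lc", "env | sort", image="example/image:latest", entrypoint="bash")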
torchx/distributed/__init__.py
CHANGED
@@ -48,7 +48,7 @@ def local_rank() -> int:
         " but the `LOCAL_RANK` environment variable is not set. Will trivially return 0 for local_rank.\n"
         " It is recommended to use torchrun/torchx to run your script or set the `LOCAL_RANK` manually.\n"
         " For additional details see:\n"
-        " 1) https://pytorch.org/torchx/latest/components/distributed.html\n"
+        " 1) https://meta-pytorch.org/torchx/latest/components/distributed.html\n"
         " 2) https://pytorch.org/docs/stable/elastic/run.html\n"
         "=============================================================================================="
     )
@@ -83,9 +83,7 @@ def local_device() -> torch.device:
     if dist.is_initialized():
         default_pg = _get_default_group()
         return (
-            local_cuda_device()
-            if default_pg.options.backend == "nccl"
-            else torch.device("cpu")
+            local_cuda_device() if default_pg.name() == "nccl" else torch.device("cpu")
         )
     else:
         return torch.device("cuda") if has_cuda_devices() else torch.device("cpu")
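For context, a minimal sketch of how the helpers patched above might be used in a training script; it assumes only local_rank() and local_device() as defined in this module:

import torch
from torchx.distributed import local_device, local_rank

# cuda:<LOCAL_RANK> when the default process group reports "nccl", otherwise cpu
# (plain cuda/cpu based on device availability if no process group is initialized)
device = local_device()
model = torch.nn.Linear(16, 16).to(device)
print(f"local rank {local_rank()} -> {device}")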

torchx/examples/apps/lightning/data.py
CHANGED

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 Trainer Datasets Example
 ========================
@@ -62,17 +64,17 @@ class ImageFolderSamplesDataset(datasets.ImageFolder):
 # our trainer and other components that need to load data.


-# pyre-fixme[13]: Attribute `test_ds` is never initialized.
-# pyre-fixme[13]: Attribute `train_ds` is never initialized.
-# pyre-fixme[13]: Attribute `val_ds` is never initialized.
 class TinyImageNetDataModule(pl.LightningDataModule):
     """
     TinyImageNetDataModule is a pytorch LightningDataModule for the tiny
     imagenet dataset.
     """

+    # pyre-fixme[13]: Attribute `test_ds` is never initialized.
     train_ds: ImageFolderSamplesDataset
+    # pyre-fixme[13]: Attribute `train_ds` is never initialized.
     val_ds: ImageFolderSamplesDataset
+    # pyre-fixme[13]: Attribute `val_ds` is never initialized.
     test_ds: ImageFolderSamplesDataset

     def __init__(

torchx/examples/apps/lightning/model.py
CHANGED

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 Tiny ImageNet Model
 ====================
@@ -21,7 +23,7 @@ import pytorch_lightning as pl
 import torch
 import torch.jit
 from torch.nn import functional as F
-from torchmetrics import
+from torchmetrics.classification import MulticlassAccuracy
 from torchvision.models.resnet import BasicBlock, ResNet


@@ -42,13 +44,12 @@ class TinyImageNetModel(pl.LightningModule):

         # We use the torchvision resnet model with some small tweaks to match
         # TinyImageNet.
-        m = ResNet(BasicBlock, layer_sizes)
+        m = ResNet(BasicBlock, layer_sizes, num_classes=200)
         m.avgpool = torch.nn.AdaptiveAvgPool2d(1)
-        m.fc.out_features = 200
         self.model: ResNet = m

-        self.train_acc =
-        self.val_acc =
+        self.train_acc = MulticlassAccuracy(num_classes=m.fc.out_features)
+        self.val_acc = MulticlassAccuracy(num_classes=m.fc.out_features)

     # pyre-fixme[14]
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -69,7 +70,7 @@ class TinyImageNetModel(pl.LightningModule):
     def _step(
         self,
         step_name: str,
-        acc_metric:
+        acc_metric: MulticlassAccuracy,
         batch: Tuple[torch.Tensor, torch.Tensor],
         batch_idx: int,
     ) -> torch.Tensor:
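The model.py hunks replace the bare torchmetrics import with torchmetrics.classification.MulticlassAccuracy and pass num_classes to the ResNet constructor instead of patching fc.out_features after construction. A standalone sketch of the resulting pattern; the layer sizes are illustrative:

import torch
from torchmetrics.classification import MulticlassAccuracy
from torchvision.models.resnet import BasicBlock, ResNet

m = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=200)  # resnet18-style layers, 200 TinyImageNet classes
m.avgpool = torch.nn.AdaptiveAvgPool2d(1)
train_acc = MulticlassAccuracy(num_classes=m.fc.out_features)
val_acc = MulticlassAccuracy(num_classes=m.fc.out_features)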

torchx/examples/apps/lightning/profiler.py
CHANGED

@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 Simple Logging Profiler
 ===========================
@@ -17,18 +19,19 @@ output is used for HPO optimization with Ax.
 import time
 from typing import Dict

-from pytorch_lightning.loggers.
-
+from pytorch_lightning.loggers.logger import Logger
+
+from pytorch_lightning.profilers.profiler import Profiler


-class SimpleLoggingProfiler(
+class SimpleLoggingProfiler(Profiler):
     """
     This profiler records the duration of actions (in seconds) and reports the
     mean duration of each action to the specified logger. Reported metrics are
     in the format `duration_<event>`.
     """

-    def __init__(self, logger:
+    def __init__(self, logger: Logger) -> None:
         super().__init__()

         self.current_actions: Dict[str, float] = {}

torchx/examples/apps/lightning/train.py
CHANGED

@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 Trainer Example
 =============================================
@@ -58,6 +60,7 @@ import pytorch_lightning as pl
 import torch
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
+from torch.distributed.elastic.multiprocessing import errors
 from torchx.examples.apps.lightning.data import (
     create_random_data,
     download_data,
@@ -83,7 +86,12 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
     parser.add_argument(
         "--batch_size", type=int, default=32, help="batch size to use for training"
     )
-    parser.add_argument(
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=32,
+        help="number of samples in the dataset",
+    )
     parser.add_argument(
         "--data_path",
         type=str,
@@ -124,6 +132,7 @@ def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
     )


+@errors.record
 def main(argv: List[str]) -> None:
     with tempfile.TemporaryDirectory() as tmpdir:
         args = parse_args(argv)
@@ -136,7 +145,7 @@ def main(argv: List[str]) -> None:
         if not args.data_path:
             data_path = os.path.join(tmpdir, "data")
             os.makedirs(data_path)
-            create_random_data(data_path)
+            create_random_data(data_path, args.num_samples)
         else:
             data_path = download_data(args.data_path, tmpdir)

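The train.py change wraps the entrypoint with torchelastic's error recorder. A minimal sketch of that pattern on its own, assuming only the import shown in the hunk above:

import sys
from typing import List

from torch.distributed.elastic.multiprocessing import errors


@errors.record
def main(argv: List[str]) -> None:
    # exceptions raised here are recorded so torchelastic can report the root-cause worker failure
    ...


if __name__ == "__main__":
    main(sys.argv[1:])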

torchx/examples/torchx_out_of_sync_training.py
ADDED

@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# This file is for lex team to ramp up on torchX OSS. It's for training purpose.
+# Please DO NOT modify this file unless you know what you are doing.
torchx/notebook.py
CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 This contains TorchX utilities for creating and running components and apps from
 an Jupyter/IPython Notebook.