nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import signal
|
|
5
|
+
import subprocess
|
|
6
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
7
|
+
from datetime import timedelta
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Literal
|
|
10
|
+
|
|
11
|
+
from typing_extensions import (
|
|
12
|
+
TypeAlias,
|
|
13
|
+
TypedDict,
|
|
14
|
+
TypeVar,
|
|
15
|
+
TypeVarTuple,
|
|
16
|
+
Unpack,
|
|
17
|
+
assert_never,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from . import lsf, slurm
|
|
21
|
+
from ._output import SubmitOutput
|
|
22
|
+
|
|
23
|
+
TArgs = TypeVarTuple("TArgs")
_Path: TypeAlias = str | Path | os.PathLike

log = logging.getLogger(__name__)


class GenericJobKwargs(TypedDict, total=False):
    """Scheduler-agnostic job options.

    Every key is optional. ``_to_slurm`` and ``_to_lsf`` translate these keys
    into the option names of the Slurm and LSF backends respectively. Options
    a backend does not support are dropped: on LSF, ``constraint`` and
    ``requeue_on_preempt`` are ignored with a logged warning, and ``qos`` is not
    forwarded at all.
    """

    name: str
    """The name of the job."""

    partition: str | Sequence[str]
    """
    The partition (Slurm) or queue (LSF) to submit the job to.

    Alias of `queue`; setting both to different values raises a `ValueError`.
    """

    queue: str | Sequence[str]
    """
    The queue (LSF) or partition (Slurm) to submit the job to.

    Alias of `partition`; setting both to different values raises a `ValueError`.
    """

    qos: str
    """
    The quality of service to submit the job to.

    This corresponds to the "--qos" option in sbatch (only for Slurm); it is not forwarded to LSF.
    """

    account: str
    """
    The account (or project) to charge the job to.

    Alias of `project`; setting both to different values raises a `ValueError`.
    """

    project: str
    """
    The project (or account) to charge the job to.

    Alias of `account`; setting both to different values raises a `ValueError`.
    """

    output_file: _Path
    """
    The file to write the job output to.

    For LSF this corresponds to the "-o" option in bsub; for Slurm it is forwarded as
    `output_file` to the Slurm backend. If not specified, the scheduler's default output file is used.
    """

    error_file: _Path
    """
    The file to write the job errors to.

    For LSF this corresponds to the "-e" option in bsub; for Slurm it is forwarded as
    `error_file` to the Slurm backend. If not specified, the scheduler's default error file is used.
    """

    nodes: int
    """The number of nodes to request."""

    tasks_per_node: int
    """The number of tasks to request per node (Slurm: `ntasks_per_node`, LSF: `rs_per_node`)."""

    cpus_per_task: int
    """The number of CPUs to request per task (LSF: `cpus_per_rs`)."""

    gpus_per_task: int
    """The number of GPUs to request per task (LSF: `gpus_per_rs`)."""

    memory_mb: int
    """The maximum memory for the job in MB."""

    walltime: timedelta
    """The maximum walltime for the job (Slurm: `time`, LSF: `walltime`)."""

    email: str
    """The email address to send notifications to (Slurm: `mail_user`, LSF: `email`)."""

    notifications: set[Literal["begin", "end"]]
    """
    The events to send email notifications for.

    Slurm receives these as `mail_type` entries ("BEGIN"/"END"); LSF as `notify_begin`/`notify_end`.
    """

    setup_commands: Sequence[str]
    """
    The setup commands to run before the job.

    These commands will be executed prior to everything else in the job script.
    """

    environment: Mapping[str, str]
    """
    The environment variables to set for the job.

    These variables will be set prior to executing any commands in the job script.
    """

    command_prefix: str
    """
    A command to prefix the job command with.

    This is used to add commands like `srun` or `jsrun` to the job command.
    """

    constraint: str | Sequence[str]
    """
    The constraint to request for the job.

    For Slurm, this corresponds to the `--constraint` option. LSF does not support
    constraints; the value is ignored and a warning is logged.
    """

    signal: signal.Signals
    """The signal that will be sent to the job when it is time to stop it."""

    command_template: str
    """
    The template for the command to execute the helper script.

    Default: `bash {script}`.

    NOTE(review): neither `_to_slurm` nor `_to_lsf` currently forwards this key to
    the backend, so setting it here has no effect; confirm the intended mapping.
    """

    requeue_on_preempt: bool
    """
    Whether to requeue the job if it is preempted.

    This corresponds to the "--requeue" option in sbatch (only for Slurm); LSF
    ignores it and logs a warning.
    """

    slurm_options: slurm.SlurmJobKwargs
    """Additional keyword arguments for Slurm jobs; these override values derived from the generic options."""

    lsf_options: lsf.LSFJobKwargs
    """Additional keyword arguments for LSF jobs; these override values derived from the generic options."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
Scheduler: TypeAlias = Literal["slurm", "lsf"]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
T = TypeVar("T", infer_variance=True)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _one_of(*fns: Callable[[], T | None]) -> T | None:
|
|
147
|
+
values = [value for fn in fns if (value := fn()) is not None]
|
|
148
|
+
|
|
149
|
+
# Only one (or zero) value should be set. If not, raise an error.
|
|
150
|
+
if len(set(values)) > 1:
|
|
151
|
+
raise ValueError(f"Multiple values set: {values}")
|
|
152
|
+
|
|
153
|
+
return next((value for value in values if value is not None), None)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _to_slurm(kwargs: GenericJobKwargs) -> slurm.SlurmJobKwargs:
    """Translate scheduler-agnostic job options into Slurm job options.

    Generic keys are copied (and renamed where Slurm uses a different name).
    The aliases ``account``/``project`` and ``partition``/``queue`` are collapsed
    with ``_one_of``, and notifications become Slurm ``mail_type`` entries.
    Options that are absent or ``None`` are left out. Entries in
    ``slurm_options`` are applied last and override derived values.

    Raises:
        ValueError: If aliased options conflict or a notification type is unknown.
    """
    out: dict[str, Any] = {}

    def forward(*pairs: tuple[str, str]) -> None:
        # Copy each generic option to its Slurm key, skipping unset options.
        for generic_key, slurm_key in pairs:
            value = kwargs.get(generic_key)
            if value is not None:
                out[slurm_key] = value

    forward(("name", "name"))

    account = _one_of(
        lambda: kwargs.get("account"),
        lambda: kwargs.get("project"),
    )
    if account is not None:
        out["account"] = account

    partition = _one_of(
        lambda: kwargs.get("partition"),
        lambda: kwargs.get("queue"),
    )
    if partition is not None:
        out["partition"] = partition

    forward(
        ("qos", "qos"),
        ("output_file", "output_file"),
        ("error_file", "error_file"),
        ("walltime", "time"),
        ("memory_mb", "memory_mb"),
        ("nodes", "nodes"),
        ("tasks_per_node", "ntasks_per_node"),
        ("cpus_per_task", "cpus_per_task"),
        ("gpus_per_task", "gpus_per_task"),
        ("constraint", "constraint"),
        ("signal", "signal"),
        ("email", "mail_user"),
    )

    notifications = kwargs.get("notifications")
    if notifications is not None:
        mail_type: list[slurm.MailType] = []
        for notification in notifications:
            if notification == "begin":
                mail_type.append("BEGIN")
            elif notification == "end":
                mail_type.append("END")
            else:
                raise ValueError(f"Unknown notification type: {notification}")
        out["mail_type"] = mail_type

    forward(
        ("setup_commands", "setup_commands"),
        ("environment", "environment"),
        ("command_prefix", "command_prefix"),
        ("requeue_on_preempt", "requeue"),
    )

    # Backend-specific overrides win over anything derived above.
    extra = kwargs.get("slurm_options")
    if extra is not None:
        out.update(extra)

    return out  # type: ignore[return-value]
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _to_lsf(kwargs: GenericJobKwargs) -> lsf.LSFJobKwargs:
    """Translate scheduler-agnostic job options into LSF job options.

    The aliases ``account``/``project`` become ``project`` and
    ``partition``/``queue`` become ``queue``. The per-task resource options map
    to the LSF backend's ``rs_per_node``, ``cpus_per_rs`` and ``gpus_per_rs``.
    Notifications become ``notify_begin``/``notify_end``. ``constraint`` and a
    requested ``requeue_on_preempt`` are not supported by LSF; they are ignored
    and a warning is logged. Entries in ``lsf_options`` are applied last and
    override derived values.

    Raises:
        ValueError: If aliased options conflict, or a notification type is
            unknown (validated the same way as in ``_to_slurm``).
    """
    lsf_kwargs: lsf.LSFJobKwargs = {}
    if (name := kwargs.get("name")) is not None:
        lsf_kwargs["name"] = name
    # `account` and `project` are aliases; LSF calls this the project.
    if (
        account := _one_of(
            lambda: kwargs.get("account"),
            lambda: kwargs.get("project"),
        )
    ) is not None:
        lsf_kwargs["project"] = account
    # `partition` and `queue` are aliases; LSF calls this the queue.
    if (
        partition := _one_of(
            lambda: kwargs.get("partition"),
            lambda: kwargs.get("queue"),
        )
    ) is not None:
        lsf_kwargs["queue"] = partition
    if (output_file := kwargs.get("output_file")) is not None:
        lsf_kwargs["output_file"] = output_file
    if (error_file := kwargs.get("error_file")) is not None:
        lsf_kwargs["error_file"] = error_file
    if (walltime := kwargs.get("walltime")) is not None:
        lsf_kwargs["walltime"] = walltime
    if (memory_mb := kwargs.get("memory_mb")) is not None:
        lsf_kwargs["memory_mb"] = memory_mb
    if (nodes := kwargs.get("nodes")) is not None:
        lsf_kwargs["nodes"] = nodes
    # The LSF backend expresses per-task resources with its `*_rs*` option names.
    if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
        lsf_kwargs["rs_per_node"] = tasks_per_node
    if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
        lsf_kwargs["cpus_per_rs"] = cpus_per_task
    if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
        lsf_kwargs["gpus_per_rs"] = gpus_per_task
    if (constraint := kwargs.get("constraint")) is not None:
        log.warning(f'LSF does not support constraints, ignoring "{constraint=}".')
    if (email := kwargs.get("email")) is not None:
        lsf_kwargs["email"] = email
    if (notifications := kwargs.get("notifications")) is not None:
        # Reject unknown notification types, matching `_to_slurm`. This way
        # `validate_kwargs` reports bad input regardless of the scheduler.
        unknown = [n for n in notifications if n not in ("begin", "end")]
        if unknown:
            raise ValueError(f"Unknown notification type: {unknown[0]}")
        if "begin" in notifications:
            lsf_kwargs["notify_begin"] = True
        if "end" in notifications:
            lsf_kwargs["notify_end"] = True
    if (setup_commands := kwargs.get("setup_commands")) is not None:
        lsf_kwargs["setup_commands"] = setup_commands
    if (environment := kwargs.get("environment")) is not None:
        lsf_kwargs["environment"] = environment
    if (command_prefix := kwargs.get("command_prefix")) is not None:
        lsf_kwargs["command_prefix"] = command_prefix
    # Named `job_signal` so the local does not shadow the `signal` module.
    if (job_signal := kwargs.get("signal")) is not None:
        lsf_kwargs["signal"] = job_signal
    # Only warn when requeueing was actually requested. `False` asks for nothing
    # that is being dropped.
    if requeue_on_preempt := kwargs.get("requeue_on_preempt"):
        log.warning(
            f'LSF does not support requeueing, ignoring "{requeue_on_preempt=}".'
        )
    # NOTE(review): `command_template` is not forwarded here (nor by
    # `_to_slurm`); confirm whether the LSF backend should receive it.
    if (additional_kwargs := kwargs.get("lsf_options")) is not None:
        lsf_kwargs.update(additional_kwargs)

    return lsf_kwargs
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def validate_kwargs(scheduler: Scheduler, kwargs: GenericJobKwargs) -> None:
    """Check that ``kwargs`` can be converted into options for ``scheduler``.

    The scheduler-specific conversion runs on a deep copy of ``kwargs``, so the
    caller's mapping is never modified. The converted result is discarded;
    any exception raised by the conversion (e.g. ``ValueError`` for
    conflicting aliases) propagates to the caller.
    """
    if scheduler == "slurm":
        convert = _to_slurm
    elif scheduler == "lsf":
        convert = _to_lsf
    else:
        assert_never(scheduler)

    convert(copy.deepcopy(kwargs))
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def to_array_batch_script(
    scheduler: Scheduler,
    dest: Path,
    callable: Callable[[Unpack[TArgs]], Any],
    args_list: Sequence[tuple[Unpack[TArgs]]],
    /,
    job_index_variable: str | None = None,
    print_environment_info: bool = False,
    python_command_prefix: str | None = None,
    **kwargs: Unpack[GenericJobKwargs],
) -> SubmitOutput:
    """Create an array batch script that runs ``callable`` once per ``args_list`` entry.

    The generic job options in ``kwargs`` are converted with ``_to_slurm`` or
    ``_to_lsf``. The call is then delegated to ``slurm.to_array_batch_script``
    or ``lsf.to_array_batch_script`` with the remaining arguments.

    Args:
        scheduler: Which backend to generate the script for.
        dest: Passed through to the backend unchanged.
        callable: The function each array task invokes.
        args_list: One argument tuple per array task.
        job_index_variable: Forwarded only when it is not ``None``; otherwise
            the keyword is omitted and the backend function's own handling applies.
        print_environment_info: Forwarded to the backend.
        python_command_prefix: Forwarded to the backend.
        **kwargs: Scheduler-agnostic job options.

    Returns:
        The ``SubmitOutput`` returned by the backend.

    Raises:
        ValueError: If the generic job options cannot be converted.
    """
    shared_options: dict[str, Any] = {}
    if job_index_variable is not None:
        shared_options["job_index_variable"] = job_index_variable
    shared_options["print_environment_info"] = print_environment_info
    shared_options["python_command_prefix"] = python_command_prefix

    if scheduler == "slurm":
        return slurm.to_array_batch_script(
            dest,
            callable,
            args_list,
            **shared_options,
            **_to_slurm(kwargs),
        )
    if scheduler == "lsf":
        return lsf.to_array_batch_script(
            dest,
            callable,
            args_list,
            **shared_options,
            **_to_lsf(kwargs),
        )
    assert_never(scheduler)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def infer_current_scheduler() -> Scheduler:
    """Detect which batch scheduler is available on this machine.

    LSF is probed first (``bsub -V``) because it is much less common than
    Slurm; Slurm is probed second (``sbatch --version``). A scheduler counts as
    available when its command can be launched and exits with status 0. The
    probes' stdout and stderr are discarded so that version banners do not
    clutter the terminal.

    Returns:
        ``"lsf"`` or ``"slurm"``.

    Raises:
        RuntimeError: If neither probe command succeeds.
    """
    probes: tuple[tuple[Scheduler, list[str]], ...] = (
        ("lsf", ["bsub", "-V"]),
        ("slurm", ["sbatch", "--version"]),
    )
    for scheduler, command in probes:
        try:
            subprocess.run(
                command,
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.SubprocessError):
            # The command is missing or not executable (OSError), or it exited
            # non-zero (CalledProcessError). Try the next scheduler.
            # KeyboardInterrupt/SystemExit are deliberately not caught.
            continue
        return scheduler

    raise RuntimeError("Could not determine the current scheduler.")
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from ._callback import ActSaveCallback as ActSaveCallback
|
|
2
|
+
from ._loader import ActivationLoader as ActivationLoader
|
|
3
|
+
from ._loader import ActLoad as ActLoad
|
|
4
|
+
from ._saver import Activation as Activation
|
|
5
|
+
from ._saver import ActivationSaver as ActivationSaver
|
|
6
|
+
from ._saver import ActSave as ActSave
|
|
7
|
+
from ._saver import Transform as Transform
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
from typing import TYPE_CHECKING, Literal, cast
|
|
3
|
+
|
|
4
|
+
from lightning.pytorch import LightningModule, Trainer
|
|
5
|
+
from lightning.pytorch.callbacks.callback import Callback
|
|
6
|
+
from typing_extensions import TypeAlias, override
|
|
7
|
+
|
|
8
|
+
from ._saver import ActSave
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..model.config import BaseConfig
|
|
12
|
+
|
|
13
|
+
Stage: TypeAlias = Literal["train", "validation", "test", "predict"]


class ActSaveCallback(Callback):
    """Lightning callback that wraps each stage's epoch in an ``ActSave.context``.

    ``on_<stage>_epoch_start`` enters ``ActSave.context(stage)`` and
    ``on_<stage>_epoch_end`` exits it. A context is only opened when the module's
    ``hparams.trainer.actsave`` is truthy.
    """

    def __init__(self):
        super().__init__()

        # Currently entered ActSave contexts, keyed by stage.
        self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}

    def _exit_context(self, stage: Stage) -> None:
        """Exit and forget the active context for ``stage``, if there is one."""
        # Pop (rather than get) so that a context is exited at most once. This
        # also stops a finished context being kept alive after its epoch ends.
        if (context := self._active_contexts.pop(stage, None)) is not None:
            context.__exit__(None, None, None)

    def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
        hparams = cast("BaseConfig", pl_module.hparams)
        if not hparams.trainer.actsave:
            return

        # A context can still be open here only if the previous epoch's end
        # hook never ran; close it before opening a new one.
        self._exit_context(stage)

        # Enter a new context manager for this stage
        context = ActSave.context(stage)
        context.__enter__()
        self._active_contexts[stage] = context

    def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
        # Close whatever `_on_start` opened. This intentionally does not re-check
        # `hparams.trainer.actsave`, so an opened context cannot leak if the
        # flag changes mid-epoch. Without an open context this is a no-op.
        self._exit_context(stage)

    @override
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("train", trainer, pl_module)

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("train", trainer, pl_module)

    @override
    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("validation", trainer, pl_module)

    @override
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("validation", trainer, pl_module)

    @override
    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("test", trainer, pl_module)

    @override
    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("test", trainer, pl_module)

    @override
    def on_predict_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("predict", trainer, pl_module)

    @override
    def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("predict", trainer, pl_module)
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import pprint
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from logging import getLogger
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import cast, overload
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from typing_extensions import TypeVar, override
|
|
10
|
+
|
|
11
|
+
log = getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
T = TypeVar("T", infer_variance=True)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class LoadedActivation:
|
|
18
|
+
base_dir: Path = field(repr=False)
|
|
19
|
+
name: str
|
|
20
|
+
num_activations: int = field(init=False)
|
|
21
|
+
activation_files: list[Path] = field(init=False, repr=False)
|
|
22
|
+
|
|
23
|
+
def __post_init__(self):
|
|
24
|
+
if not self.activation_dir.exists():
|
|
25
|
+
raise ValueError(f"Activation dir {self.activation_dir} does not exist")
|
|
26
|
+
|
|
27
|
+
# The number of activations = the * of .npy files in the activation dir
|
|
28
|
+
self.activation_files = list(self.activation_dir.glob("*.npy"))
|
|
29
|
+
# Sort the activation files by the numerical index in the filename
|
|
30
|
+
self.activation_files.sort(key=lambda p: int(p.stem))
|
|
31
|
+
self.num_activations = len(self.activation_files)
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def activation_dir(self) -> Path:
|
|
35
|
+
return self.base_dir / self.name
|
|
36
|
+
|
|
37
|
+
def _load_activation(self, item: int):
|
|
38
|
+
activation_path = self.activation_files[item]
|
|
39
|
+
if not activation_path.exists():
|
|
40
|
+
raise ValueError(f"Activation {activation_path} does not exist")
|
|
41
|
+
return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
|
|
42
|
+
|
|
43
|
+
@overload
|
|
44
|
+
def __getitem__(self, item: int) -> np.ndarray: ...
|
|
45
|
+
|
|
46
|
+
@overload
|
|
47
|
+
def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
|
|
48
|
+
|
|
49
|
+
def __getitem__(
|
|
50
|
+
self, item: int | slice | list[int]
|
|
51
|
+
) -> np.ndarray | list[np.ndarray]:
|
|
52
|
+
if isinstance(item, int):
|
|
53
|
+
return self._load_activation(item)
|
|
54
|
+
elif isinstance(item, slice):
|
|
55
|
+
return [
|
|
56
|
+
self._load_activation(i)
|
|
57
|
+
for i in range(*item.indices(self.num_activations))
|
|
58
|
+
]
|
|
59
|
+
elif isinstance(item, list):
|
|
60
|
+
return [self._load_activation(i) for i in item]
|
|
61
|
+
else:
|
|
62
|
+
raise TypeError(f"Invalid type {type(item)} for item {item}")
|
|
63
|
+
|
|
64
|
+
def __iter__(self):
|
|
65
|
+
return iter(self[i] for i in range(self.num_activations))
|
|
66
|
+
|
|
67
|
+
def __len__(self):
|
|
68
|
+
return self.num_activations
|
|
69
|
+
|
|
70
|
+
def all_activations(self):
|
|
71
|
+
return [self[i] for i in range(self.num_activations)]
|
|
72
|
+
|
|
73
|
+
@override
|
|
74
|
+
def __repr__(self):
|
|
75
|
+
return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class ActLoad:
    """Read access to one version directory of saved activations.

    Each sub-directory of the version directory is one activation, exposed as
    a ``LoadedActivation``. The classmethods operate on an *activation base*
    directory instead. That is a directory marked by a ``.activationbase``
    file, whose integer-named sub-directories are the versions.
    """

    @classmethod
    def all_versions(cls, dir: str | Path):
        """List ``(path, version)`` pairs for the versions inside an activation base.

        Sub-directories whose names are not integers are skipped.

        Returns:
            The list of pairs (possibly empty), or ``None`` if ``dir`` has no
            ``.activationbase`` marker.
        """
        dir = Path(dir)

        # If the dir is not an activation base directory, we return None
        if not (dir / ".activationbase").exists():
            return None

        versions: list[tuple[Path, int]] = []
        for subdir in dir.iterdir():
            if not subdir.is_dir():
                continue
            try:
                version = int(subdir.name)
            except ValueError:
                # Not a version directory; skip it rather than failing.
                continue
            versions.append((subdir, version))
        return versions

    @classmethod
    def is_valid_activation_base(cls, dir: str | Path):
        """Return whether ``dir`` is an activation base directory."""
        return cls.all_versions(dir) is not None

    @classmethod
    def from_latest_version(cls, dir: str | Path):
        """Open the highest-numbered version inside the activation base ``dir``.

        Raises:
            ValueError: If ``dir`` is not an activation base directory, or if it
                contains no version directories.
        """
        if (all_versions := cls.all_versions(dir)) is None:
            raise ValueError(f"{dir} is not an activation base directory")
        if not all_versions:
            raise ValueError(f"{dir} does not contain any activation versions")

        path, _ = max(all_versions, key=lambda p: p[1])
        return cls(path)

    def __init__(self, dir: str | Path):
        # Normalize to Path so plain strings work with the Path methods used below.
        self._dir = Path(dir)

    def activation(self, name: str):
        """Return the activation called ``name`` (its directory must exist)."""
        return LoadedActivation(self._dir, name)

    @cached_property
    def activations(self):
        """All activations in this version, keyed by name, in order of modification time.

        Only sub-directories count as activations. The mapping is computed once
        and cached on the instance.
        """
        dirs = [p for p in self._dir.iterdir() if p.is_dir()]
        # Sort the dirs by the last modified time
        dirs.sort(key=lambda p: p.stat().st_mtime)

        return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}

    def __iter__(self):
        return iter(self.activations.values())

    def __getitem__(self, item: str):
        return self.activations[item]

    def __len__(self):
        return len(self.activations)

    @override
    def __repr__(self):
        acts_str = pprint.pformat(
            {
                name: f"<{activation.num_activations} activations>"
                for name, activation in self.activations.items()
            }
        )
        # Drop the quotes pformat puts around the "<N activations>" strings.
        acts_str = acts_str.replace("'<", "<").replace(">'", ">")
        return f"ActLoad({acts_str})"

    def get(self, name: str, /, default: T) -> LoadedActivation | T:
        """Return the activation called ``name``, or ``default`` if there is none."""
        return self.activations.get(name, default)


ActivationLoader = ActLoad
|