nshtrainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,350 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import signal
5
+ import subprocess
6
+ from collections.abc import Callable, Mapping, Sequence
7
+ from datetime import timedelta
8
+ from pathlib import Path
9
+ from typing import Any, Literal
10
+
11
+ from typing_extensions import (
12
+ TypeAlias,
13
+ TypedDict,
14
+ TypeVar,
15
+ TypeVarTuple,
16
+ Unpack,
17
+ assert_never,
18
+ )
19
+
20
+ from . import lsf, slurm
21
+ from ._output import SubmitOutput
22
+
23
# Variadic type variable tying the positional parameters of the submitted
# callable to the element type of `args_list` in `to_array_batch_script`.
TArgs = TypeVarTuple("TArgs")
# Path-like values accepted for the `output_file` / `error_file` job options.
_Path: TypeAlias = str | Path | os.PathLike

# Module-level logger; used to warn about options a scheduler cannot honor.
log = logging.getLogger(__name__)
27
+
28
+
29
class GenericJobKwargs(TypedDict, total=False):
    """Scheduler-agnostic job options.

    Every key is optional. The options are translated to the Slurm or LSF
    specific option dictionaries by `_to_slurm` / `_to_lsf`. Keys documented
    as aliases (`partition`/`queue`, `account`/`project`) may both be given
    only if they carry the same value; otherwise the translation raises
    `ValueError`.
    """

    name: str
    """The name of the job."""

    partition: str | Sequence[str]
    """The partition or queue to submit the job to. Same as `queue`."""

    queue: str | Sequence[str]
    """The queue to submit the job to. Same as `partition`."""

    qos: str
    """
    The quality of service to submit the job to.

    This corresponds to the "--qos" option in sbatch (only for Slurm).
    """

    account: str
    """The account (or project) to charge the job to. Same as `project`."""

    project: str
    """The project (or account) to charge the job to. Same as `account`."""

    output_file: _Path
    """
    The file to write the job output to.

    For LSF, this corresponds to the "-o" option in bsub. If not specified, the output will be written to the default output file.
    """

    error_file: _Path
    """
    The file to write the job errors to.

    For LSF, this corresponds to the "-e" option in bsub. If not specified, the errors will be written to the default error file.
    """

    nodes: int
    """The number of nodes to request."""

    tasks_per_node: int
    """The number of tasks to request per node (LSF: resource sets per node)."""

    cpus_per_task: int
    """The number of CPUs to request per task (LSF: CPUs per resource set)."""

    gpus_per_task: int
    """The number of GPUs to request per task (LSF: GPUs per resource set)."""

    memory_mb: int
    """The maximum memory for the job in MB."""

    walltime: timedelta
    """The maximum walltime for the job."""

    email: str
    """The email address to send notifications to."""

    notifications: set[Literal["begin", "end"]]
    """The notifications to send via email."""

    setup_commands: Sequence[str]
    """
    The setup commands to run before the job.

    These commands will be executed prior to everything else in the job script.
    """

    environment: Mapping[str, str]
    """
    The environment variables to set for the job.

    These variables will be set prior to executing any commands in the job script.
    """

    command_prefix: str
    """
    A command to prefix the job command with.

    This is used to add commands like `srun` or `jsrun` to the job command.
    """

    constraint: str | Sequence[str]
    """
    The constraint to request for the job. For Slurm, this corresponds to the `--constraint` option. For LSF, this is unused (a warning is logged).
    """

    signal: signal.Signals
    """The signal that will be sent to the job when it is time to stop it."""

    # NOTE(review): neither `_to_slurm` nor `_to_lsf` in this module reads
    # `command_template`; confirm whether it is meant to be forwarded.
    command_template: str
    """
    The template for the command to execute the helper script.

    Default: `bash {script}`.
    """

    requeue_on_preempt: bool
    """
    Whether to requeue the job if it is preempted.

    This corresponds to the "--requeue" option in sbatch (only for Slurm). For LSF, this is unused (a warning is logged).
    """

    slurm_options: slurm.SlurmJobKwargs
    """Additional keyword arguments for Slurm jobs. Applied last, so they override the generic options."""

    lsf_options: lsf.LSFJobKwargs
    """Additional keyword arguments for LSF jobs. Applied last, so they override the generic options."""
138
+
139
+
140
# The job schedulers this module can target.
Scheduler: TypeAlias = Literal["slurm", "lsf"]


# Generic value type for `_one_of`.
T = TypeVar("T", infer_variance=True)
144
+
145
+
146
+ def _one_of(*fns: Callable[[], T | None]) -> T | None:
147
+ values = [value for fn in fns if (value := fn()) is not None]
148
+
149
+ # Only one (or zero) value should be set. If not, raise an error.
150
+ if len(set(values)) > 1:
151
+ raise ValueError(f"Multiple values set: {values}")
152
+
153
+ return next((value for value in values if value is not None), None)
154
+
155
+
156
def _to_slurm(kwargs: GenericJobKwargs) -> slurm.SlurmJobKwargs:
    """Translate scheduler-agnostic job options into Slurm job options.

    Only keys that are present with a non-`None` value are copied. The
    entries of `slurm_options` are merged last and therefore override
    anything derived from the generic options.

    Raises:
        ValueError: If `account`/`project` or `partition`/`queue` disagree,
            or if `notifications` contains an unknown entry.
    """
    options: slurm.SlurmJobKwargs = {}

    job_name = kwargs.get("name")
    if job_name is not None:
        options["name"] = job_name

    # `account` and `project` are aliases for the same setting.
    charge_to = _one_of(
        lambda: kwargs.get("account"),
        lambda: kwargs.get("project"),
    )
    if charge_to is not None:
        options["account"] = charge_to

    # `partition` and `queue` are aliases for the same setting.
    target_partition = _one_of(
        lambda: kwargs.get("partition"),
        lambda: kwargs.get("queue"),
    )
    if target_partition is not None:
        options["partition"] = target_partition

    quality_of_service = kwargs.get("qos")
    if quality_of_service is not None:
        options["qos"] = quality_of_service

    stdout_path = kwargs.get("output_file")
    if stdout_path is not None:
        options["output_file"] = stdout_path

    stderr_path = kwargs.get("error_file")
    if stderr_path is not None:
        options["error_file"] = stderr_path

    time_limit = kwargs.get("walltime")
    if time_limit is not None:
        options["time"] = time_limit

    memory = kwargs.get("memory_mb")
    if memory is not None:
        options["memory_mb"] = memory

    node_count = kwargs.get("nodes")
    if node_count is not None:
        options["nodes"] = node_count

    tasks = kwargs.get("tasks_per_node")
    if tasks is not None:
        options["ntasks_per_node"] = tasks

    cpus = kwargs.get("cpus_per_task")
    if cpus is not None:
        options["cpus_per_task"] = cpus

    gpus = kwargs.get("gpus_per_task")
    if gpus is not None:
        options["gpus_per_task"] = gpus

    node_constraint = kwargs.get("constraint")
    if node_constraint is not None:
        options["constraint"] = node_constraint

    stop_signal = kwargs.get("signal")
    if stop_signal is not None:
        options["signal"] = stop_signal

    address = kwargs.get("email")
    if address is not None:
        options["mail_user"] = address

    requested_notifications = kwargs.get("notifications")
    if requested_notifications is not None:
        mail_types: list[slurm.MailType] = []
        for entry in requested_notifications:
            if entry == "begin":
                mail_types.append("BEGIN")
            elif entry == "end":
                mail_types.append("END")
            else:
                raise ValueError(f"Unknown notification type: {entry}")
        options["mail_type"] = mail_types

    setup = kwargs.get("setup_commands")
    if setup is not None:
        options["setup_commands"] = setup

    env = kwargs.get("environment")
    if env is not None:
        options["environment"] = env

    prefix = kwargs.get("command_prefix")
    if prefix is not None:
        options["command_prefix"] = prefix

    requeue = kwargs.get("requeue_on_preempt")
    if requeue is not None:
        options["requeue"] = requeue

    # Scheduler-specific overrides win over everything set above.
    overrides = kwargs.get("slurm_options")
    if overrides is not None:
        options.update(overrides)

    return options
221
+
222
+
223
def _to_lsf(kwargs: GenericJobKwargs) -> lsf.LSFJobKwargs:
    """Translate scheduler-agnostic job options into LSF job options.

    Only keys that are present with a non-`None` value are copied. Options
    that have no LSF counterpart here (`qos`, `constraint`,
    `requeue_on_preempt`) are ignored with a logged warning. The entries of
    `lsf_options` are merged last and therefore override anything derived
    from the generic options.

    Raises:
        ValueError: If `account`/`project` or `partition`/`queue` disagree.
    """
    lsf_kwargs: lsf.LSFJobKwargs = {}
    if (name := kwargs.get("name")) is not None:
        lsf_kwargs["name"] = name
    # `account` and `project` are aliases for the same setting.
    if (
        account := _one_of(
            lambda: kwargs.get("account"),
            lambda: kwargs.get("project"),
        )
    ) is not None:
        lsf_kwargs["project"] = account
    # `partition` and `queue` are aliases for the same setting.
    if (
        partition := _one_of(
            lambda: kwargs.get("partition"),
            lambda: kwargs.get("queue"),
        )
    ) is not None:
        lsf_kwargs["queue"] = partition
    if (qos := kwargs.get("qos")) is not None:
        # Warn instead of dropping the option silently, matching how the
        # other Slurm-only options below are handled.
        log.warning(f'LSF does not support QoS, ignoring "{qos=}".')
    if (output_file := kwargs.get("output_file")) is not None:
        lsf_kwargs["output_file"] = output_file
    if (error_file := kwargs.get("error_file")) is not None:
        lsf_kwargs["error_file"] = error_file
    if (walltime := kwargs.get("walltime")) is not None:
        lsf_kwargs["walltime"] = walltime
    if (memory_mb := kwargs.get("memory_mb")) is not None:
        lsf_kwargs["memory_mb"] = memory_mb
    if (nodes := kwargs.get("nodes")) is not None:
        lsf_kwargs["nodes"] = nodes
    # Generic "tasks" map onto LSF resource sets ("rs").
    if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
        lsf_kwargs["rs_per_node"] = tasks_per_node
    if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
        lsf_kwargs["cpus_per_rs"] = cpus_per_task
    if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
        lsf_kwargs["gpus_per_rs"] = gpus_per_task
    if (constraint := kwargs.get("constraint")) is not None:
        log.warning(f'LSF does not support constraints, ignoring "{constraint=}".')
    if (email := kwargs.get("email")) is not None:
        lsf_kwargs["email"] = email
    if (notifications := kwargs.get("notifications")) is not None:
        if "begin" in notifications:
            lsf_kwargs["notify_begin"] = True
        if "end" in notifications:
            lsf_kwargs["notify_end"] = True
    if (setup_commands := kwargs.get("setup_commands")) is not None:
        lsf_kwargs["setup_commands"] = setup_commands
    if (environment := kwargs.get("environment")) is not None:
        lsf_kwargs["environment"] = environment
    if (command_prefix := kwargs.get("command_prefix")) is not None:
        lsf_kwargs["command_prefix"] = command_prefix
    # Named `stop_signal` so the local does not shadow the `signal` module.
    if (stop_signal := kwargs.get("signal")) is not None:
        lsf_kwargs["signal"] = stop_signal
    if (requeue_on_preempt := kwargs.get("requeue_on_preempt")) is not None:
        log.warning(
            f'LSF does not support requeueing, ignoring "{requeue_on_preempt=}".'
        )
    # Scheduler-specific overrides win over everything set above.
    if (additional_kwargs := kwargs.get("lsf_options")) is not None:
        lsf_kwargs.update(additional_kwargs)

    return lsf_kwargs
282
+
283
+
284
def validate_kwargs(scheduler: Scheduler, kwargs: GenericJobKwargs) -> None:
    """Check that `kwargs` can be translated for the given scheduler.

    The translation runs on a deep copy and its result is discarded; the
    call only serves to surface any error the conversion raises.
    """
    if scheduler == "slurm":
        _to_slurm(copy.deepcopy(kwargs))
    elif scheduler == "lsf":
        _to_lsf(copy.deepcopy(kwargs))
    else:
        assert_never(scheduler)
292
+
293
+
294
def to_array_batch_script(
    scheduler: Scheduler,
    dest: Path,
    callable: Callable[[Unpack[TArgs]], Any],
    args_list: Sequence[tuple[Unpack[TArgs]]],
    /,
    job_index_variable: str | None = None,
    print_environment_info: bool = False,
    python_command_prefix: str | None = None,
    **kwargs: Unpack[GenericJobKwargs],
) -> SubmitOutput:
    """Dispatch array batch-script generation to the selected scheduler.

    The generic `kwargs` are translated with `_to_slurm` / `_to_lsf` and the
    call is forwarded to that scheduler module's `to_array_batch_script`,
    whose result is returned unchanged.

    Args:
        scheduler: Which scheduler backend to use.
        dest: Forwarded to the backend as its first positional argument.
        callable: The function to run; forwarded to the backend.
        args_list: One argument tuple per invocation; forwarded to the backend.
        job_index_variable: Forwarded only when not `None`, so the backend's
            own default applies otherwise.
        print_environment_info: Forwarded to the backend.
        python_command_prefix: Forwarded to the backend.
        **kwargs: Scheduler-agnostic job options.
    """
    # Leave the key out entirely when unset so the backend default is used.
    index_kwargs = {}
    if job_index_variable is not None:
        index_kwargs["job_index_variable"] = job_index_variable

    if scheduler == "slurm":
        backend_kwargs = _to_slurm(kwargs)
        return slurm.to_array_batch_script(
            dest,
            callable,
            args_list,
            **index_kwargs,
            print_environment_info=print_environment_info,
            python_command_prefix=python_command_prefix,
            **backend_kwargs,
        )
    elif scheduler == "lsf":
        backend_kwargs = _to_lsf(kwargs)
        return lsf.to_array_batch_script(
            dest,
            callable,
            args_list,
            **index_kwargs,
            print_environment_info=print_environment_info,
            python_command_prefix=python_command_prefix,
            **backend_kwargs,
        )
    else:
        assert_never(scheduler)
333
+
334
+
335
def infer_current_scheduler() -> Scheduler:
    """Detect which scheduler is available by probing its submit command.

    A scheduler counts as available when its version command can be started
    and exits successfully.

    Returns:
        `"lsf"` if `bsub -V` succeeds, otherwise `"slurm"` if
        `sbatch --version` succeeds.

    Raises:
        RuntimeError: If neither command could be run successfully.
    """
    # `bsub` is probed first as it's much less common than `sbatch`.
    probes: tuple[tuple[Scheduler, list[str]], ...] = (
        ("lsf", ["bsub", "-V"]),
        ("slurm", ["sbatch", "--version"]),
    )
    for scheduler, command in probes:
        try:
            subprocess.check_output(command)
        except (OSError, subprocess.SubprocessError):
            # Command missing, not executable, or exited non-zero: try the
            # next candidate. Deliberately narrower than `BaseException` so
            # that KeyboardInterrupt / SystemExit still propagate.
            continue
        return scheduler

    raise RuntimeError("Could not determine the current scheduler.")
@@ -0,0 +1,7 @@
1
+ from ._callback import ActSaveCallback as ActSaveCallback
2
+ from ._loader import ActivationLoader as ActivationLoader
3
+ from ._loader import ActLoad as ActLoad
4
+ from ._saver import Activation as Activation
5
+ from ._saver import ActivationSaver as ActivationSaver
6
+ from ._saver import ActSave as ActSave
7
+ from ._saver import Transform as Transform
@@ -0,0 +1,75 @@
1
+ import contextlib
2
+ from typing import TYPE_CHECKING, Literal, cast
3
+
4
+ from lightning.pytorch import LightningModule, Trainer
5
+ from lightning.pytorch.callbacks.callback import Callback
6
+ from typing_extensions import TypeAlias, override
7
+
8
+ from ._saver import ActSave
9
+
10
+ if TYPE_CHECKING:
11
+ from ..model.config import BaseConfig
12
+
13
# Loop stages the callback tracks separately; the stage name is also what is
# passed to `ActSave.context`.
Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
14
+
15
+
16
class ActSaveCallback(Callback):
    """Callback that keeps an `ActSave` context open during each epoch.

    At the start of every train/validation/test/predict epoch the callback
    enters `ActSave.context(<stage>)`, and at the end of that epoch it exits
    it again. Nothing happens unless `pl_module.hparams.trainer.actsave` is
    truthy.
    """

    def __init__(self):
        super().__init__()

        # The context currently entered for each stage (at most one each).
        self._active_contexts: dict[Stage, contextlib._GeneratorContextManager] = {}

    def _exit_active_context(self, stage: Stage) -> None:
        """Exit and forget the context entered for `stage`, if there is one."""
        # `pop` rather than `get`: once exited, the context must not linger
        # in the mapping, otherwise the next start hook would exit it again.
        if (active_context := self._active_contexts.pop(stage, None)) is not None:
            active_context.__exit__(None, None, None)

    def _on_start(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
        """Open a fresh `ActSave` context for `stage`."""
        hparams = cast("BaseConfig", pl_module.hparams)
        if not hparams.trainer.actsave:
            return

        # If a context for this stage is still open (no end hook ran), close
        # it before opening the new one.
        self._exit_active_context(stage)

        # Enter a new context manager for this stage
        context = ActSave.context(stage)
        context.__enter__()
        self._active_contexts[stage] = context

    def _on_end(self, stage: Stage, trainer: Trainer, pl_module: LightningModule):
        """Close the `ActSave` context for `stage`, if one is open."""
        hparams = cast("BaseConfig", pl_module.hparams)
        if not hparams.trainer.actsave:
            return

        self._exit_active_context(stage)

    @override
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("train", trainer, pl_module)

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("train", trainer, pl_module)

    @override
    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("validation", trainer, pl_module)

    @override
    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("validation", trainer, pl_module)

    @override
    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("test", trainer, pl_module)

    @override
    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("test", trainer, pl_module)

    @override
    def on_predict_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_start("predict", trainer, pl_module)

    @override
    def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        return self._on_end("predict", trainer, pl_module)
@@ -0,0 +1,144 @@
1
+ import pprint
2
+ from dataclasses import dataclass, field
3
+ from functools import cached_property
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import cast, overload
7
+
8
+ import numpy as np
9
+ from typing_extensions import TypeVar, override
10
+
11
# Module-level logger.
log = getLogger(__name__)

# Type of the caller-supplied fallback value in `ActLoad.get`.
T = TypeVar("T", infer_variance=True)
14
+
15
+
16
+ @dataclass
17
+ class LoadedActivation:
18
+ base_dir: Path = field(repr=False)
19
+ name: str
20
+ num_activations: int = field(init=False)
21
+ activation_files: list[Path] = field(init=False, repr=False)
22
+
23
+ def __post_init__(self):
24
+ if not self.activation_dir.exists():
25
+ raise ValueError(f"Activation dir {self.activation_dir} does not exist")
26
+
27
+ # The number of activations = the * of .npy files in the activation dir
28
+ self.activation_files = list(self.activation_dir.glob("*.npy"))
29
+ # Sort the activation files by the numerical index in the filename
30
+ self.activation_files.sort(key=lambda p: int(p.stem))
31
+ self.num_activations = len(self.activation_files)
32
+
33
+ @property
34
+ def activation_dir(self) -> Path:
35
+ return self.base_dir / self.name
36
+
37
+ def _load_activation(self, item: int):
38
+ activation_path = self.activation_files[item]
39
+ if not activation_path.exists():
40
+ raise ValueError(f"Activation {activation_path} does not exist")
41
+ return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
42
+
43
+ @overload
44
+ def __getitem__(self, item: int) -> np.ndarray: ...
45
+
46
+ @overload
47
+ def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
48
+
49
+ def __getitem__(
50
+ self, item: int | slice | list[int]
51
+ ) -> np.ndarray | list[np.ndarray]:
52
+ if isinstance(item, int):
53
+ return self._load_activation(item)
54
+ elif isinstance(item, slice):
55
+ return [
56
+ self._load_activation(i)
57
+ for i in range(*item.indices(self.num_activations))
58
+ ]
59
+ elif isinstance(item, list):
60
+ return [self._load_activation(i) for i in item]
61
+ else:
62
+ raise TypeError(f"Invalid type {type(item)} for item {item}")
63
+
64
+ def __iter__(self):
65
+ return iter(self[i] for i in range(self.num_activations))
66
+
67
+ def __len__(self):
68
+ return self.num_activations
69
+
70
+ def all_activations(self):
71
+ return [self[i] for i in range(self.num_activations)]
72
+
73
+ @override
74
+ def __repr__(self):
75
+ return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
76
+
77
+
78
class ActLoad:
    """Read-only access to the activations stored in one version directory.

    An "activation base" directory is recognised by a `.activationbase`
    marker and contains one sub-directory per version, named by an integer.
    Each version directory in turn contains one sub-directory per activation
    name.
    """

    @classmethod
    def all_versions(cls, dir: str | Path):
        """List the version directories of an activation base.

        Returns:
            A list of `(path, version_number)` tuples, or `None` if `dir` is
            not an activation base directory.

        Raises:
            ValueError: If a sub-directory name is not an integer.
        """
        dir = Path(dir)

        # If the dir is not an activation base directory, we return None
        if not (dir / ".activationbase").exists():
            return None

        # The contents of `dir` should be directories, each of which is a version.
        return [
            (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir()
        ]

    @classmethod
    def is_valid_activation_base(cls, dir: str | Path):
        """Return whether `dir` is an activation base directory."""
        return cls.all_versions(dir) is not None

    @classmethod
    def from_latest_version(cls, dir: str | Path):
        """Create a loader for the highest-numbered version under `dir`.

        Raises:
            ValueError: If `dir` is not an activation base directory or does
                not contain any version directory.
        """
        # The contents of `dir` should be directories, each of which is a version
        # We need to find the latest version
        if (all_versions := cls.all_versions(dir)) is None:
            raise ValueError(f"{dir} is not an activation base directory")
        # Give a clear error instead of the opaque one `max()` raises on an
        # empty sequence.
        if not all_versions:
            raise ValueError(f"{dir} does not contain any activation versions")

        path, _ = max(all_versions, key=lambda p: p[1])
        return cls(path)

    def __init__(self, dir: Path):
        self._dir = dir

    def activation(self, name: str):
        """Return the activation called `name` without scanning the directory."""
        return LoadedActivation(self._dir, name)

    @cached_property
    def activations(self):
        """Mapping of activation name to `LoadedActivation`.

        Computed once on first access and ordered by the directories' last
        modified time (oldest first).
        """
        # Activations are stored as directories of .npy files; skip stray
        # files so they do not show up as empty activations.
        dirs = [p for p in self._dir.iterdir() if p.is_dir()]
        # Sort the dirs by the last modified time
        dirs.sort(key=lambda p: p.stat().st_mtime)

        return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}

    def __iter__(self):
        return iter(self.activations.values())

    def __getitem__(self, item: str):
        return self.activations[item]

    def __len__(self):
        return len(self.activations)

    @override
    def __repr__(self):
        acts_str = pprint.pformat(
            {
                name: f"<{activation.num_activations} activations>"
                for name, activation in self.activations.items()
            }
        )
        # Drop the quotes pprint puts around the "<N activations>" strings.
        acts_str = acts_str.replace("'<", "<").replace(">'", ">")
        return f"ActLoad({acts_str})"

    def get(self, name: str, /, default: T) -> LoadedActivation | T:
        """Return the activation called `name`, or `default` if absent."""
        return self.activations.get(name, default)


# Alternative name for `ActLoad`.
ActivationLoader = ActLoad