nshtrainer 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
nshtrainer/_snoop.py CHANGED
@@ -1,216 +1 @@
- import contextlib
- from typing import Any, Protocol, cast
-
- from typing_extensions import TypeVar
-
- T = TypeVar("T", infer_variance=True)
-
-
- class SnoopConstructor(Protocol):
-     def __call__(self, *args, **kwargs) -> contextlib.AbstractContextManager: ...
-
-     def disable(self) -> contextlib.AbstractContextManager: ...
-
-
- try:
-     import warnings
-     from contextlib import nullcontext
-
-     import lovely_numpy as lo
-     import lovely_tensors as lt
-     import numpy
-     import pysnooper
-     import pysnooper.utils
-     import torch
-     from pkg_resources import DistributionNotFound, get_distribution
-
-     FLOATING_POINTS = set()
-     for i in ["float", "double", "half", "complex128", "complex32", "complex64"]:
-         if hasattr(torch, i):  # older version of PyTorch do not have complex dtypes
-             FLOATING_POINTS.add(getattr(torch, i))
-
-     try:
-         __version__ = get_distribution(__name__).version
-     except DistributionNotFound:
-         # package is not installed
-         pass
-
-     def default_format(x):
-         try:
-             formatted = str(lt.lovely(x))
-             return formatted
-         except BaseException:
-             return str(x.shape)
-
-     def default_numpy_format(x):
-         return str(lo.lovely(x))
-
-     class TorchSnooper(pysnooper.tracer.Tracer):
-         def __init__(
-             self,
-             *args,
-             tensor_format=default_format,
-             numpy_format=default_numpy_format,
-             **kwargs,
-         ):
-             self.orig_custom_repr = (
-                 kwargs["custom_repr"] if "custom_repr" in kwargs else ()
-             )
-             custom_repr = (lambda x: True, self.compute_repr)
-             kwargs["custom_repr"] = (custom_repr,)
-             super(TorchSnooper, self).__init__(*args, **kwargs)
-             self.tensor_format = tensor_format
-             self.numpy_format = numpy_format
-
-         @staticmethod
-         def is_return_types(x):
-             return type(x).__module__ == "torch.return_types"
-
-         def return_types_repr(self, x):
-             if type(x).__name__ in {
-                 "max",
-                 "min",
-                 "median",
-                 "mode",
-                 "sort",
-                 "topk",
-                 "kthvalue",
-             }:
-                 return (
-                     type(x).__name__
-                     + "(values="
-                     + self.tensor_format(x.values)
-                     + ", indices="
-                     + self.tensor_format(x.indices)
-                     + ")"
-                 )
-             if type(x).__name__ == "svd":
-                 return (
-                     "svd(U="
-                     + self.tensor_format(x.U)
-                     + ", S="
-                     + self.tensor_format(x.S)
-                     + ", V="
-                     + self.tensor_format(x.V)
-                     + ")"
-                 )
-             if type(x).__name__ == "slogdet":
-                 return (
-                     "slogdet(sign="
-                     + self.tensor_format(x.sign)
-                     + ", logabsdet="
-                     + self.tensor_format(x.logabsdet)
-                     + ")"
-                 )
-             if type(x).__name__ == "qr":
-                 return (
-                     "qr(Q="
-                     + self.tensor_format(x.Q)
-                     + ", R="
-                     + self.tensor_format(x.R)
-                     + ")"
-                 )
-             if type(x).__name__ == "solve":
-                 return (
-                     "solve(solution="
-                     + self.tensor_format(x.solution)
-                     + ", LU="
-                     + self.tensor_format(x.LU)
-                     + ")"
-                 )
-             if type(x).__name__ == "geqrf":
-                 return (
-                     "geqrf(a="
-                     + self.tensor_format(x.a)
-                     + ", tau="
-                     + self.tensor_format(x.tau)
-                     + ")"
-                 )
-             if type(x).__name__ in {"symeig", "eig"}:
-                 return (
-                     type(x).__name__
-                     + "(eigenvalues="
-                     + self.tensor_format(x.eigenvalues)
-                     + ", eigenvectors="
-                     + self.tensor_format(x.eigenvectors)
-                     + ")"
-                 )
-             if type(x).__name__ == "triangular_solve":
-                 return (
-                     "triangular_solve(solution="
-                     + self.tensor_format(x.solution)
-                     + ", cloned_coefficient="
-                     + self.tensor_format(x.cloned_coefficient)
-                     + ")"
-                 )
-             if type(x).__name__ == "gels":
-                 return (
-                     "gels(solution="
-                     + self.tensor_format(x.solution)
-                     + ", QR="
-                     + self.tensor_format(x.QR)
-                     + ")"
-                 )
-             warnings.warn("Unknown return_types encountered, open a bug report!")
-
-         def compute_repr(self, x):
-             orig_repr_func = pysnooper.utils.get_repr_function(x, self.orig_custom_repr)
-             if torch.is_tensor(x):
-                 return self.tensor_format(x)
-             if isinstance(x, numpy.ndarray):
-                 return self.numpy_format(x)
-             if self.is_return_types(x):
-                 return self.return_types_repr(x)
-             if orig_repr_func is not repr:
-                 return orig_repr_func(x)
-             if isinstance(x, (list, tuple)):
-                 content = ""
-                 for i in x:
-                     if content != "":
-                         content += ", "
-                     content += self.compute_repr(i)
-                 if isinstance(x, tuple) and len(x) == 1:
-                     content += ","
-                 if isinstance(x, tuple):
-                     return "(" + content + ")"
-                 return "[" + content + "]"
-             if isinstance(x, dict):
-                 content = ""
-                 for k, v in x.items():
-                     if content != "":
-                         content += ", "
-                     content += self.compute_repr(k) + ": " + self.compute_repr(v)
-                 return "{" + content + "}"
-             return repr(x)
-
-     class _Snoop:
-         disable = nullcontext
-         __call__ = TorchSnooper
-
-     snoop: SnoopConstructor = cast(Any, _Snoop())
-
- except ImportError:
-     import warnings
-     from contextlib import nullcontext
-
-     from typing_extensions import override
-
-     _has_warned = False
-
-     class _snoop_cls(nullcontext):
-         @classmethod
-         def disable(cls):
-             return nullcontext()
-
-         @override
-         def __enter__(self):
-             global _has_warned
-             if not _has_warned:
-                 warnings.warn(
-                     "snoop is not installed, please install it to enable snoop"
-                 )
-                 _has_warned = True
-
-             return super().__enter__()
-
-     snoop: SnoopConstructor = cast(Any, _snoop_cls)
+ from nshutils.snoop import * # type: ignore # noqa: F403
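The whole module now reduces to a single re-export. The wildcard import implies that nshutils.snoop provides the same `snoop` constructor the deleted module defined (callable, usable as a context manager, with a `disable()` escape hatch), though the nshutils side is not shown in this diff. A minimal sketch of a call site under that assumption:

    # Hypothetical call site; assumes the re-exported `snoop` matches the
    # SnoopConstructor protocol removed above.
    import torch

    from nshtrainer._snoop import snoop


    def double(x: torch.Tensor) -> torch.Tensor:
        # Lines executed inside the context are traced, with tensors
        # summarized via lovely-tensors when it is installed.
        with snoop():
            return x * 2


    double(torch.randn(3))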
nshtrainer/actsave/__init__.py CHANGED
@@ -1,7 +1,3 @@
+ from nshutils.actsave import * # type: ignore # noqa: F403
+
  from ._callback import ActSaveCallback as ActSaveCallback
- from ._loader import ActivationLoader as ActivationLoader
- from ._loader import ActLoad as ActLoad
- from ._saver import Activation as Activation
- from ._saver import ActivationSaver as ActivationSaver
- from ._saver import ActSave as ActSave
- from ._saver import Transform as Transform
nshtrainer/actsave/_callback.py CHANGED
@@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Literal, cast

  from lightning.pytorch import LightningModule, Trainer
  from lightning.pytorch.callbacks.callback import Callback
+ from nshutils.actsave import ActSave
  from typing_extensions import TypeAlias, override

- from ._saver import ActSave
-
  if TYPE_CHECKING:
      from ..model.config import BaseConfig

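Taken together, the two actsave hunks move the public API to nshutils and keep only the Lightning callback local. A sketch of the import surface that should be unchanged for downstream code, assuming the wildcard re-export preserves the deleted names:

    # Assumed-stable imports after the upgrade: ActSave and ActLoad now
    # resolve through nshutils.actsave; ActSaveCallback remains local.
    from nshtrainer.actsave import ActLoad, ActSave, ActSaveCallback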
nshtrainer/typecheck.py CHANGED
@@ -1,145 +1 @@
- import os
- from collections.abc import Sequence
- from logging import getLogger
- from typing import Any
-
- import numpy as np
- import torch
- from jaxtyping import BFloat16 as BFloat16
- from jaxtyping import Bool as Bool
- from jaxtyping import Complex as Complex
- from jaxtyping import Complex64 as Complex64
- from jaxtyping import Complex128 as Complex128
- from jaxtyping import Float as Float
- from jaxtyping import Float16 as Float16
- from jaxtyping import Float32 as Float32
- from jaxtyping import Float64 as Float64
- from jaxtyping import Inexact as Inexact
- from jaxtyping import Int as Int
- from jaxtyping import Int4 as Int4
- from jaxtyping import Int8 as Int8
- from jaxtyping import Int16 as Int16
- from jaxtyping import Int32 as Int32
- from jaxtyping import Int64 as Int64
- from jaxtyping import Integer as Integer
- from jaxtyping import Key as Key
- from jaxtyping import Num as Num
- from jaxtyping import Real as Real
- from jaxtyping import Shaped as Shaped
- from jaxtyping import UInt as UInt
- from jaxtyping import UInt4 as UInt4
- from jaxtyping import UInt8 as UInt8
- from jaxtyping import UInt16 as UInt16
- from jaxtyping import UInt32 as UInt32
- from jaxtyping import UInt64 as UInt64
- from jaxtyping._storage import get_shape_memo, shape_str
- from torch import Tensor as Tensor
- from torch.nn.parameter import Parameter as Parameter
- from typing_extensions import TypeVar
-
- log = getLogger(__name__)
-
- DISABLE_ENV_KEY = "LL_DISABLE_TYPECHECKING"
-
-
- def typecheck_modules(modules: Sequence[str]):
-     """
-     Typecheck the given modules using `jaxtyping`.
-
-     Args:
-         modules: Modules to typecheck.
-     """
-     # If `DISABLE_ENV_KEY` is set and the environment variable is set, skip
-     # typechecking.
-     if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
-         log.critical(
-             f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
-         )
-         return
-
-     # Install the jaxtyping import hook for this module.
-     from jaxtyping import install_import_hook
-
-     install_import_hook(modules, "beartype.beartype")
-
-     log.critical(f"Type checking the following modules: {modules}")
-
-
- def typecheck_this_module(additional_modules: Sequence[str] = ()):
-     """
-     Typecheck the calling module and any additional modules using `jaxtyping`.
-
-     Args:
-         additional_modules: Additional modules to typecheck.
-     """
-     # Get the calling module's name.
-     # Here, we can just use beartype's internal implementation behind
-     # `beartype_this_package`.
-     from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name
-
-     # Get the calling module's name.
-     assert get_frame is not None, "get_frame is None"
-     frame = get_frame(1)
-     assert frame is not None, "frame is None"
-     calling_module_name = get_frame_package_name(frame)
-     assert calling_module_name is not None, "calling_module_name is None"
-
-     # Typecheck the calling module + any additional modules.
-     typecheck_modules((calling_module_name, *additional_modules))
-
-
- def _make_error_str(input: Any, t: Any) -> str:
-     error_components: list[str] = []
-     error_components.append("Type checking error:")
-     if hasattr(t, "__instancecheck_str__"):
-         error_components.append(t.__instancecheck_str__(input))
-     if torch.is_tensor(input):
-         try:
-             from lovely_tensors import lovely
-
-             error_components.append(repr(lovely(input)))
-         except BaseException:
-             error_components.append(repr(input.shape))
-     error_components.append(shape_str(get_shape_memo()))
-
-     return "\n".join(error_components)
-
-
- T = TypeVar("T", torch.Tensor, np.ndarray, infer_variance=True)
-
- """
- Patch to jaxtyping:
-
- In `jaxtyping._import_hook`, we add:
- def _has_isinstance_or_tassert(func_def):
-     for node in ast.walk(func_def):
-         if isinstance(node, ast.Call):
-             if isinstance(node.func, ast.Name) and node.func.id == "isinstance":
-                 return True
-             elif isinstance(node.func, ast.Name) and node.func.id == "tassert":
-                 return True
-     return False
-
- and we check this when adding the decorators.
- """
-
-
- def tassert(t: Any, input: T | tuple[T, ...]):
-     """
-     Typecheck the input against the given type.
-
-     Args:
-         t: Type to check against.
-         input: Input to check.
-     """
-
-     # Ignore typechecking if the environment variable is set.
-     if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
-         return
-
-     if isinstance(input, tuple):
-         for i in input:
-             assert isinstance(i, t), _make_error_str(i, t)
-         return
-     else:
-         assert isinstance(input, t), _make_error_str(input, t)
+ from nshutils.typecheck import * # type: ignore # noqa: F403
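As with _snoop.py, the module body is replaced by a re-export. A sketch of a call site that should remain source-compatible, assuming nshutils.typecheck keeps the jaxtyping re-exports and the `tassert` helper that the deleted module provided:

    # Hypothetical call site; names are assumed to survive the move to nshutils.
    import torch

    from nshtrainer.typecheck import Float, tassert


    def scale(x):
        # Fails with a descriptive error (shape memo, lovely-tensors repr)
        # unless `x` is a floating-point tensor of shape (batch, dim).
        tassert(Float[torch.Tensor, "batch dim"], x)
        return x * 2


    scale(torch.randn(4, 8))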
{nshtrainer-0.2.0.dist-info → nshtrainer-0.4.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.2.0
+ Version: 0.4.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
@@ -9,15 +9,13 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
- Requires-Dist: beartype (>=0.18.5,<0.19.0)
- Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
  Requires-Dist: lightning
  Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
  Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
  Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
  Requires-Dist: nshrunner (>=0.5.4,<0.6.0)
+ Requires-Dist: nshutils (>=0.3.0,<0.4.0)
  Requires-Dist: numpy
- Requires-Dist: pysnooper
  Requires-Dist: pytorch-lightning
  Requires-Dist: rich
  Requires-Dist: torch
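The dependency surface shrinks accordingly: beartype, jaxtyping, and pysnooper drop out as direct requirements, replaced by a single pin on nshutils. One way to check the resolved versions in an upgraded environment (standard library only):

    # Quick sanity check after upgrading; expects nshtrainer 0.4.0 and an
    # nshutils release in the [0.3.0, 0.4.0) range pinned above.
    from importlib.metadata import version

    for pkg in ("nshtrainer", "nshutils"):
        print(pkg, version(pkg))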
{nshtrainer-0.2.0.dist-info → nshtrainer-0.4.0.dist-info}/RECORD RENAMED
@@ -3,11 +3,9 @@ nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
- nshtrainer/_snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
- nshtrainer/actsave/__init__.py,sha256=G1T-fELuGWkVqdhdyoePtj2dTOUtcIOW4VgsXv9JNTA,338
- nshtrainer/actsave/_callback.py,sha256=QoTa60F70f1RxB41VKixN9l5_htfFQxXDPHHSNFreuk,2770
- nshtrainer/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
- nshtrainer/actsave/_saver.py,sha256=0EHmQDhqVxQWRWWSyt03eP1K9ETiACMQYmsZkDMt6HY,9451
+ nshtrainer/_snoop.py,sha256=2rEemPyMP3aIo2QgPzo_-AlT1oXGWYQipId4RQskMls,58
+ nshtrainer/actsave/__init__.py,sha256=_ZuwgRtF1-ekouXNvtZCAS1g_IDYGB4NX8BFSGNGBT8,119
+ nshtrainer/actsave/_callback.py,sha256=mnHOtuG9vtHEzz9q4vCvDNC6VvjZsgb4MSSuOoUDh3M,2778
  nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
  nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
@@ -54,12 +52,12 @@ nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwB
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
  nshtrainer/trainer/signal_connector.py,sha256=QAoPM_C5JJOVQebcrJOimUUD3GHyoeZUqCEAvzZlT4U,8710
  nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
- nshtrainer/typecheck.py,sha256=RGYHxDBcs97E6ayl6Olc43JBZXQolCtMxcLBniVCVBg,4688
+ nshtrainer/typecheck.py,sha256=ryV1Tzcf7hJ4I19H1oQVkikU9spmRk8jyIKQZ5UF7pQ,62
  nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
  nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.2.0.dist-info/METADATA,sha256=cwb3IbKGyJ9HbNSvsORYhCiI61nrDMb1dVm5nE1q_XA,882
- nshtrainer-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.2.0.dist-info/RECORD,,
+ nshtrainer-0.4.0.dist-info/METADATA,sha256=V4fr_C3pnSpaXT1KHAmnDFLcDznTo7AZZDAPFPm3AUk,812
+ nshtrainer-0.4.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.4.0.dist-info/RECORD,,
nshtrainer/actsave/_loader.py DELETED
@@ -1,144 +0,0 @@
- import pprint
- from dataclasses import dataclass, field
- from functools import cached_property
- from logging import getLogger
- from pathlib import Path
- from typing import cast, overload
-
- import numpy as np
- from typing_extensions import TypeVar, override
-
- log = getLogger(__name__)
-
- T = TypeVar("T", infer_variance=True)
-
-
- @dataclass
- class LoadedActivation:
-     base_dir: Path = field(repr=False)
-     name: str
-     num_activations: int = field(init=False)
-     activation_files: list[Path] = field(init=False, repr=False)
-
-     def __post_init__(self):
-         if not self.activation_dir.exists():
-             raise ValueError(f"Activation dir {self.activation_dir} does not exist")
-
-         # The number of activations = the # of .npy files in the activation dir
-         self.activation_files = list(self.activation_dir.glob("*.npy"))
-         # Sort the activation files by the numerical index in the filename
-         self.activation_files.sort(key=lambda p: int(p.stem))
-         self.num_activations = len(self.activation_files)
-
-     @property
-     def activation_dir(self) -> Path:
-         return self.base_dir / self.name
-
-     def _load_activation(self, item: int):
-         activation_path = self.activation_files[item]
-         if not activation_path.exists():
-             raise ValueError(f"Activation {activation_path} does not exist")
-         return cast(np.ndarray, np.load(activation_path, allow_pickle=True))
-
-     @overload
-     def __getitem__(self, item: int) -> np.ndarray: ...
-
-     @overload
-     def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ...
-
-     def __getitem__(
-         self, item: int | slice | list[int]
-     ) -> np.ndarray | list[np.ndarray]:
-         if isinstance(item, int):
-             return self._load_activation(item)
-         elif isinstance(item, slice):
-             return [
-                 self._load_activation(i)
-                 for i in range(*item.indices(self.num_activations))
-             ]
-         elif isinstance(item, list):
-             return [self._load_activation(i) for i in item]
-         else:
-             raise TypeError(f"Invalid type {type(item)} for item {item}")
-
-     def __iter__(self):
-         return iter(self[i] for i in range(self.num_activations))
-
-     def __len__(self):
-         return self.num_activations
-
-     def all_activations(self):
-         return [self[i] for i in range(self.num_activations)]
-
-     @override
-     def __repr__(self):
-         return f"<LoadedActivation {self.name} ({self.num_activations} activations)>"
-
-
- class ActLoad:
-     @classmethod
-     def all_versions(cls, dir: str | Path):
-         dir = Path(dir)
-
-         # If the dir is not an activation base directory, we return None
-         if not (dir / ".activationbase").exists():
-             return None
-
-         # The contents of `dir` should be directories, each of which is a version.
-         return [
-             (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir()
-         ]
-
-     @classmethod
-     def is_valid_activation_base(cls, dir: str | Path):
-         return cls.all_versions(dir) is not None
-
-     @classmethod
-     def from_latest_version(cls, dir: str | Path):
-         # The contents of `dir` should be directories, each of which is a version
-         # We need to find the latest version
-         if (all_versions := cls.all_versions(dir)) is None:
-             raise ValueError(f"{dir} is not an activation base directory")
-
-         path, _ = max(all_versions, key=lambda p: p[1])
-         return cls(path)
-
-     def __init__(self, dir: Path):
-         self._dir = dir
-
-     def activation(self, name: str):
-         return LoadedActivation(self._dir, name)
-
-     @cached_property
-     def activations(self):
-         dirs = list(self._dir.iterdir())
-         # Sort the dirs by the last modified time
-         dirs.sort(key=lambda p: p.stat().st_mtime)
-
-         return {p.name: LoadedActivation(self._dir, p.name) for p in dirs}
-
-     def __iter__(self):
-         return iter(self.activations.values())
-
-     def __getitem__(self, item: str):
-         return self.activations[item]
-
-     def __len__(self):
-         return len(self.activations)
-
-     @override
-     def __repr__(self):
-         acts_str = pprint.pformat(
-             {
-                 name: f"<{activation.num_activations} activations>"
-                 for name, activation in self.activations.items()
-             }
-         )
-         acts_str = acts_str.replace("'<", "<").replace(">'", ">")
-         return f"ActLoad({acts_str})"
-
-     def get(self, name: str, /, default: T) -> LoadedActivation | T:
-         return self.activations.get(name, default)
-
-
- ActivationLoader = ActLoad
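The deleted file defined the read side of actsave: ActLoad.from_latest_version picks the highest-numbered run under a directory marked with .activationbase, indexing by name yields a LoadedActivation, and integer indexing lazily loads one .npy file. A hypothetical round trip, assuming nshutils.actsave keeps this interface (the wildcard re-export in actsave/__init__.py suggests it, but the nshutils side is not shown here):

    # Hypothetical; "./activations" is a placeholder path that must already
    # contain saved runs. Names mirror the deleted file above.
    from pathlib import Path

    from nshtrainer.actsave import ActLoad

    acts = ActLoad.from_latest_version(Path("./activations"))
    print(acts)  # ActLoad({name: <N activations>, ...})
    first = acts["some.name"][0]  # first saved array, as a np.ndarray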
nshtrainer/actsave/_saver.py DELETED
@@ -1,337 +0,0 @@
- import contextlib
- import fnmatch
- import tempfile
- import uuid
- import weakref
- from collections.abc import Callable, Mapping
- from dataclasses import dataclass
- from functools import wraps
- from logging import getLogger
- from pathlib import Path
- from typing import Generic, TypeAlias, cast, overload
-
- import numpy as np
- import torch
- from lightning_utilities.core.apply_func import apply_to_collection
- from typing_extensions import ParamSpec, TypeVar, override
-
- log = getLogger(__name__)
-
- Value: TypeAlias = int | float | complex | bool | str | np.ndarray | torch.Tensor | None
- ValueOrLambda = Value | Callable[..., Value]
-
-
- def _to_numpy(activation: Value) -> np.ndarray:
-     # Make sure it's not `None`
-     if activation is None:
-         raise ValueError("Activation should not be `None`")
-
-     if isinstance(activation, np.ndarray):
-         return activation
-     if isinstance(activation, torch.Tensor):
-         activation = activation.detach()
-         if activation.is_floating_point():
-             # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy
-             activation = activation.float()
-         return activation.cpu().numpy()
-     if isinstance(activation, (int, float, complex, str, bool)):
-         return np.array(activation)
-
-     return activation
-
-
- T = TypeVar("T", infer_variance=True)
-
-
- # A wrapper around weakref.ref that allows for primitive types
- # To get around errors like:
- # TypeError: cannot create weak reference to 'int' object
- class WeakRef(Generic[T]):
-     _ref: Callable[[], T] | None
-
-     def __init__(self, obj: T):
-         try:
-             self._ref = cast(Callable[[], T], weakref.ref(obj))
-         except TypeError as e:
-             if "cannot create weak reference" not in str(e):
-                 raise
-             self._ref = lambda: obj
-
-     def __call__(self) -> T:
-         if self._ref is None:
-             raise RuntimeError("WeakRef is deleted")
-         return self._ref()
-
-     def delete(self):
-         del self._ref
-         self._ref = None
-
-
- @dataclass
- class Activation:
-     name: str
-     ref: WeakRef[ValueOrLambda] | None
-     transformed: np.ndarray | None = None
-
-     def __post_init__(self):
-         # Update the `name` to replace `/` with `.`
-         self.name = self.name.replace("/", ".")
-
-     def __call__(self) -> np.ndarray | None:
-         # If we have a transformed value, we return it
-         if self.transformed is not None:
-             return self.transformed
-
-         if self.ref is None:
-             raise RuntimeError("Activation is deleted")
-
-         # If we have a lambda, we need to call it
-         unrwapped_ref = self.ref()
-         activation = unrwapped_ref
-         if callable(unrwapped_ref):
-             activation = unrwapped_ref()
-
-         # If we have a `None`, we return early
-         if activation is None:
-             return None
-
-         activation = apply_to_collection(activation, torch.Tensor, _to_numpy)
-         activation = _to_numpy(activation)
-
-         # Set the transformed value
-         self.transformed = activation
-
-         # Delete the reference
-         self.ref.delete()
-         del self.ref
-         self.ref = None
-
-         return self.transformed
-
-     @classmethod
-     def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda):
-         return cls(name, WeakRef(value_or_lambda))
-
-     @classmethod
-     def from_dict(cls, d: Mapping[str, ValueOrLambda]):
-         return [cls.from_value_or_lambda(k, v) for k, v in d.items()]
-
-
- Transform = Callable[[Activation], Mapping[str, ValueOrLambda]]
-
-
- def _ensure_supported():
-     try:
-         import torch.distributed as dist
-
-         if dist.is_initialized() and dist.get_world_size() > 1:
-             raise RuntimeError("Only single GPU is supported at the moment")
-     except ImportError:
-         pass
-
-
- P = ParamSpec("P")
-
-
- def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]:
-     @wraps(fn)
-     def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
-         if torch.jit.is_scripting():
-             return
-
-         _ensure_supported()
-         fn(*args, **kwargs)
-
-     return wrapper
-
-
- class _Saver:
-     def __init__(
-         self,
-         save_dir: Path,
-         prefixes_fn: Callable[[], list[str]],
-         *,
-         filters: list[str] | None = None,
-     ):
-         # Create a directory under `save_dir` by autoincrementing
-         # (i.e., every activation save context, we create a new directory)
-         # The id = the number of activation subdirectories
-         self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir())
-         save_dir.mkdir(parents=True, exist_ok=True)
-
-         # Add a .activationbase file to the save_dir to indicate that this is an activation base
-         (save_dir / ".activationbase").touch(exist_ok=True)
-
-         self._save_dir = save_dir / f"{self._id:04d}"
-         # Make sure `self._save_dir` does not exist and create it
-         self._save_dir.mkdir(exist_ok=False)
-
-         self._prefixes_fn = prefixes_fn
-         self._filters = filters
-
-     def _save_activation(self, activation: Activation):
-         # If the activation value is `None`, we skip it.
-         if (activation_value := activation()) is None:
-             return
-
-         # Save the activation to self._save_dir / name / {id}.npz, where id is an auto-incrementing integer
-         file_name = ".".join(self._prefixes_fn() + [activation.name])
-         path = self._save_dir / file_name
-         path.mkdir(exist_ok=True, parents=True)
-
-         # Get the next id and save the activation
-         id = len(list(path.glob("*.npy")))
-         np.save(path / f"{id:04d}.npy", activation_value)
-
-     @_ignore_if_scripting
-     def save(
-         self,
-         acts: dict[str, ValueOrLambda] | None = None,
-         /,
-         **kwargs: ValueOrLambda,
-     ):
-         kwargs.update(acts or {})
-
-         # Build activations
-         activations = Activation.from_dict(kwargs)
-
-         for activation in activations:
-             # Make sure name matches at least one filter if filters are specified
-             if self._filters is not None and all(
-                 not fnmatch.fnmatch(activation.name, f) for f in self._filters
-             ):
-                 continue
-
-             # Save the current activation
-             self._save_activation(activation)
-
-         del activations
-
-
- class ActSaveProvider:
-     _saver: _Saver | None = None
-     _prefixes: list[str] = []
-
-     def initialize(self, save_dir: Path | None = None):
-         """
-         Initializes the saver with the given configuration and save directory.
-
-         Args:
-             save_dir (Path): The directory where the saved files will be stored.
-         """
-         if self._saver is None:
-             if save_dir is None:
-                 save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}"
-                 log.critical(f"No save_dir specified, using {save_dir=}")
-             self._saver = _Saver(
-                 save_dir,
-                 lambda: self._prefixes,
-             )
-
-     @contextlib.contextmanager
-     def enabled(self, save_dir: Path | None = None):
-         """
-         Context manager that enables the actsave functionality with the specified configuration.
-
-         Args:
-             save_dir (Path): The directory where the saved files will be stored.
-         """
-         prev = self._saver
-         self.initialize(save_dir)
-         try:
-             yield
-         finally:
-             self._saver = prev
-
-     @override
-     def __init__(self):
-         super().__init__()
-
-         self._saver = None
-         self._prefixes = []
-
-     @contextlib.contextmanager
-     def context(self, label: str):
-         """
-         A context manager that adds a label to the current context.
-
-         Args:
-             label (str): The label for the context.
-         """
-         if torch.jit.is_scripting():
-             yield
-             return
-
-         if self._saver is None:
-             yield
-             return
-
-         _ensure_supported()
-
-         log.debug(f"Entering ActSave context {label}")
-         self._prefixes.append(label)
-         try:
-             yield
-         finally:
-             _ = self._prefixes.pop()
-
-     prefix = context
-
-     @overload
-     def __call__(
-         self,
-         acts: dict[str, ValueOrLambda] | None = None,
-         /,
-         **kwargs: ValueOrLambda,
-     ):
-         """
-         Saves the activations to disk.
-
-         Args:
-             acts (dict[str, ValueOrLambda] | None, optional): A dictionary of acts. Defaults to None.
-             **kwargs (ValueOrLambda): Additional keyword arguments.
-
-         Returns:
-             None
-
-         """
-         ...
-
-     @overload
-     def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /):
-         """
-         Saves the activations to disk.
-
-         Args:
-             acts (Callable[[], dict[str, ValueOrLambda]]): A callable that returns a dictionary of acts.
-             **kwargs (ValueOrLambda): Additional keyword arguments.
-
-         Returns:
-             None
-
-         """
-         ...
-
-     def __call__(
-         self,
-         acts: (
-             dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None
-         ) = None,
-         /,
-         **kwargs: ValueOrLambda,
-     ):
-         if torch.jit.is_scripting():
-             return
-
-         if self._saver is None:
-             return
-
-         if acts is not None and callable(acts):
-             acts = acts()
-         self._saver.save(acts, **kwargs)
-
-     save = __call__
-
-
- ActSave = ActSaveProvider()
- ActivationSaver = ActSave
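And the write side, per the deleted implementation: ActSave is a process-wide singleton that stays a no-op until initialized, `enabled` scopes saving to a block, and `context` prefixes subsequent activation names. A sketch under the assumption that the nshutils replacement keeps this contract:

    # Hypothetical usage mirroring the deleted ActSaveProvider API above.
    from pathlib import Path

    import torch

    from nshtrainer.actsave import ActSave

    with ActSave.enabled(Path("./activations")):
        with ActSave.context("encoder"):  # names below become "encoder.<name>"
            ActSave(hidden=torch.randn(2, 8))
    # Written to ./activations/<run>/encoder.hidden/0000.npy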