lucid-dl 2.11.3__py3-none-any.whl → 2.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lucid/nn/module.py CHANGED
@@ -13,7 +13,22 @@ from typing import (
13
13
  from collections import OrderedDict
14
14
 
15
15
  from lucid._tensor import Tensor
16
- from lucid.types import _ArrayOrScalar, _NumPyArray, _DeviceType
16
+ from lucid.types import (
17
+ _ArrayOrScalar,
18
+ _BackwardHook,
19
+ _DeviceType,
20
+ _ForwardHook,
21
+ _ForwardHookKwargs,
22
+ _ForwardPreHook,
23
+ _ForwardPreHookKwargs,
24
+ _FullBackwardHook,
25
+ _FullBackwardPreHook,
26
+ _LoadStateDictPostHook,
27
+ _LoadStateDictPreHook,
28
+ _NumPyArray,
29
+ _StateDictHook,
30
+ _StateDictPreHook,
31
+ )
17
32
 
18
33
  import lucid.nn as nn
19
34
 
@@ -30,26 +45,6 @@ __all__ = [
30
45
  ]
31
46
 
32
47
 
33
- _ForwardPreHook = Callable[["Module", tuple[Any, ...]], tuple[Any, ...] | None]
34
- _ForwardPreHookKwargs = Callable[
35
- ["Module", tuple[Any, ...], dict[str, Any]],
36
- tuple[tuple[Any, ...], dict[str, Any]] | None,
37
- ]
38
- _ForwardHook = Callable[["Module", tuple[Any, ...], Any], Any | None]
39
- _ForwardHookKwargs = Callable[
40
- ["Module", tuple[Any, ...], dict[str, Any], Any], Any | None
41
- ]
42
-
43
- _BackwardHook = Callable[[Tensor, _NumPyArray], None]
44
- _FullBackwardPreHook = Callable[
45
- ["Module", tuple[_NumPyArray | None, ...]], tuple[_NumPyArray | None, ...] | None
46
- ]
47
- _FullBackwardHook = Callable[
48
- ["Module", tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
49
- tuple[_NumPyArray | None, ...] | None,
50
- ]
51
-
52
-
53
48
  class Module:
54
49
  _registry_map: dict[Type, OrderedDict[str, Any]] = {}
55
50
  _alt_name: str = ""
@@ -70,10 +65,17 @@ class Module:
70
65
  tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
71
66
  ] = []
72
67
  self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
68
+
73
69
  self._backward_hooks: list[_BackwardHook] = []
74
70
  self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
75
71
  self._full_backward_hooks: list[_FullBackwardHook] = []
76
72
 
73
+ self._state_dict_pre_hooks: list[_StateDictPreHook] = []
74
+ self._state_dict_hooks: list[_StateDictHook] = []
75
+
76
+ self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
77
+ self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []
78
+
77
79
  self._state_dict_pass_attr = set()
78
80
 
79
81
  def __setattr__(self, name: str, value: Any) -> None:
@@ -155,6 +157,26 @@ class Module:
155
157
  self._full_backward_hooks.append(hook)
156
158
  return lambda: self._full_backward_hooks.remove(hook)
157
159
 
160
+ def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
161
+ self._state_dict_pre_hooks.append(hook)
162
+ return lambda: self._state_dict_pre_hooks.remove(hook)
163
+
164
+ def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
165
+ self._state_dict_hooks.append(hook)
166
+ return lambda: self._state_dict_hooks.remove(hook)
167
+
168
+ def register_load_state_dict_pre_hook(
169
+ self, hook: _LoadStateDictPreHook
170
+ ) -> Callable:
171
+ self._load_state_dict_pre_hooks.append(hook)
172
+ return lambda: self._load_state_dict_pre_hooks.remove(hook)
173
+
174
+ def register_load_state_dict_post_hook(
175
+ self, hook: _LoadStateDictPostHook
176
+ ) -> Callable:
177
+ self._load_state_dict_post_hooks.append(hook)
178
+ return lambda: self._load_state_dict_post_hooks.remove(hook)
179
+
158
180
  def reset_parameters(self) -> None:
159
181
  for param in self.parameters():
160
182
  param.zero()
@@ -231,6 +253,9 @@ class Module:
231
253
  prefix: str = "",
232
254
  keep_vars: bool = False,
233
255
  ) -> OrderedDict:
256
+ for hook in self._state_dict_pre_hooks:
257
+ hook(self, prefix, keep_vars)
258
+
234
259
  if destination is None:
235
260
  destination = OrderedDict()
236
261
 
@@ -249,9 +274,15 @@ class Module:
249
274
  if key in self._state_dict_pass_attr:
250
275
  del destination[key]
251
276
 
277
+ for hook in self._state_dict_hooks:
278
+ hook(self, destination, prefix, keep_vars)
279
+
252
280
  return destination
253
281
 
254
282
  def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
283
+ for hook in self._load_state_dict_pre_hooks:
284
+ hook(self, state_dict, strict)
285
+
255
286
  own_state = self.state_dict(keep_vars=True)
256
287
 
257
288
  missing_keys = set(own_state.keys()) - set(state_dict.keys())
@@ -277,6 +308,9 @@ class Module:
277
308
  elif strict:
278
309
  raise KeyError(f"Unexpected key '{key}' in state_dict.")
279
310
 
311
+ for hook in self._load_state_dict_post_hooks:
312
+ hook(self, missing_keys, unexpected_keys, strict)
313
+
280
314
  def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
281
315
  for hook, with_kwargs in self._forward_pre_hooks:
282
316
  if with_kwargs:
lucid/types.py CHANGED
@@ -6,8 +6,10 @@ from typing import (
6
6
  Sequence,
7
7
  Literal,
8
8
  TypeAlias,
9
+ TYPE_CHECKING,
9
10
  runtime_checkable,
10
11
  )
12
+ from collections import OrderedDict
11
13
  import re
12
14
 
13
15
  import numpy as np
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
76
78
  ) -> None: ...
77
79
 
78
80
 
81
+ @runtime_checkable
82
+ class _ModuleHookable(Protocol):
83
+ def register_forward_pre_hook(
84
+ self, hook: Callable, *, with_kwargs: bool = False
85
+ ) -> Callable: ...
86
+
87
+ def register_forward_hook(
88
+ self, hook: Callable, *, with_kwargs: bool = False
89
+ ) -> Callable: ...
90
+
91
+ def register_backward_hook(self, hook: Callable) -> Callable: ...
92
+
93
+ def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
94
+
95
+ def register_full_backward_hook(self, hook: Callable) -> Callable: ...
96
+
97
+ def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
98
+
99
+ def register_state_dict_hook(self, hook: Callable) -> Callable: ...
100
+
101
+ def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
102
+
103
+ def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
104
+
105
+
106
+ _ForwardPreHook: TypeAlias = Callable[
107
+ [_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
108
+ ]
109
+ _ForwardPreHookKwargs: TypeAlias = Callable[
110
+ [_ModuleHookable, tuple[Any, ...], dict[str, Any]],
111
+ tuple[tuple[Any, ...], dict[str, Any]] | None,
112
+ ]
113
+ _ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
114
+ _ForwardHookKwargs: TypeAlias = Callable[
115
+ [_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
116
+ ]
117
+
118
+ _BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
119
+ _FullBackwardPreHook: TypeAlias = Callable[
120
+ [_ModuleHookable, tuple[_NumPyArray | None, ...]],
121
+ tuple[_NumPyArray | None, ...] | None,
122
+ ]
123
+ _FullBackwardHook: TypeAlias = Callable[
124
+ [_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
125
+ tuple[_NumPyArray | None, ...] | None,
126
+ ]
127
+
128
+ _StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
129
+ _StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
130
+
131
+ _LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
132
+ _LoadStateDictPostHook: TypeAlias = Callable[
133
+ [_ModuleHookable, set[str], set[str], bool], None
134
+ ]
135
+
136
+
79
137
  class Numeric:
80
138
  def __init__(
81
139
  self, base_dtype: type[int | float | complex], bits: int | None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lucid-dl
3
- Version: 2.11.3
3
+ Version: 2.11.4
4
4
  Summary: Lumerico's Comprehensive Interface for Deep Learning
5
5
  Home-page: https://github.com/ChanLumerico/lucid
6
6
  Author: ChanLumerico
@@ -52,15 +52,31 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
52
52
 
53
53
  ```python
54
54
  def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
55
-
55
+ ```
56
+ ```python
56
57
  def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
57
-
58
+ ```
59
+ ```python
58
60
  def register_backward_hook(self, hook: Callable)
59
-
61
+ ```
62
+ ```python
60
63
  def register_full_backward_pre_hook(self, hook: Callable)
61
-
64
+ ```
65
+ ```python
62
66
  def register_full_backward_hook(self, hook: Callable)
63
67
  ```
68
+ ```python
69
+ def register_state_dict_pre_hook(self, hook: Callable)
70
+ ```
71
+ ```python
72
+ def register_state_dict_hook(self, hook: Callable)
73
+ ```
74
+ ```python
75
+ def register_load_state_dict_pre_hook(self, hook: Callable)
76
+ ```
77
+ ```python
78
+ def register_load_state_dict_post_hook(self, hook: Callable)
79
+ ```
64
80
 
65
81
  ## 🔧 How to Install
66
82
 
@@ -1,7 +1,7 @@
1
1
  lucid/__init__.py,sha256=EwNZkKALNS54RQw_D_1BtqbBo3kFMi4uSveSs6i-bdM,9161
2
2
  lucid/error.py,sha256=qnTiVuZm3c5-DIt-OOyobZ7RUm7E1K4NR0j998LG1ug,709
3
3
  lucid/port.py,sha256=Kt1YaSWef_eKF4KRj-UFhirvFC5urEESfYQ_BSlBZGE,3811
4
- lucid/types.py,sha256=XFeLUXhtMXGh9liRPB3X49PbJSK3bcOA6KN1gvCjksA,4818
4
+ lucid/types.py,sha256=Zdz2r4ledouEG-6Gi6yEza5vSLyyTzZJn7AcRKbxy8o,6906
5
5
  lucid/_backend/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
6
6
  lucid/_backend/core.py,sha256=neZF9uQlwNp-yHeyi0IbKlN556O-lsqThsIFExsf8-Y,11791
7
7
  lucid/_backend/metal.py,sha256=vQegTENuPjeAM_EIXfuOpnIXZBeMbGVxCFvm4s-_NNo,4215
@@ -78,7 +78,7 @@ lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG299
78
78
  lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
79
79
  lucid/nn/__init__.py,sha256=_hk6KltQIJuWXowXstMSu3TjiaTP8zMLNvGpjnA9Mpw,182
80
80
  lucid/nn/fused.py,sha256=75fcXuo6fHSO-JtjuKhowhHSDr4qc5871WR63sUzH0g,5492
81
- lucid/nn/module.py,sha256=B2esFTdKJFjqsNUsFhVx_7IE-5nyJwW4b2I8QG1iLUk,25791
81
+ lucid/nn/module.py,sha256=_EWtGkAuWWCPZ5f3t5pJOOzpi14gQBpP7JW2S8o4_GE,26855
82
82
  lucid/nn/parameter.py,sha256=NQS65YKn2B59wZbZIoT1mpDsU_F08y3yLi7hmV1B6yo,1232
83
83
  lucid/nn/util.py,sha256=Yw1iBSPrGV_r_F51qpqLYdafNE_hyaA0DPWYP-rjaig,1699
84
84
  lucid/nn/_kernel/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
@@ -135,8 +135,8 @@ lucid/visual/graph.py,sha256=ZSlrJI3dQwYjz8XbgAfNd8-8YuH9Ji7Mz1J6UsnHTaI,4711
135
135
  lucid/visual/mermaid.py,sha256=87hFe4l9EYP6Cg2l2hP2INQiBHKkgVClH5nBWFY9ddY,26499
136
136
  lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
137
137
  lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
138
- lucid_dl-2.11.3.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
139
- lucid_dl-2.11.3.dist-info/METADATA,sha256=hffKVg1_fBZA5K3NcRn9KUkIv7B7ilHgr_9OJscZlR0,11898
140
- lucid_dl-2.11.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
141
- lucid_dl-2.11.3.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
142
- lucid_dl-2.11.3.dist-info/RECORD,,
138
+ lucid_dl-2.11.4.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
139
+ lucid_dl-2.11.4.dist-info/METADATA,sha256=F8r0MrpLAlRuT0IK0RFLQCurlxcj3gUYLK2-tyKhAOI,12273
140
+ lucid_dl-2.11.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
141
+ lucid_dl-2.11.4.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
142
+ lucid_dl-2.11.4.dist-info/RECORD,,