lucid-dl 2.11.3__py3-none-any.whl → 2.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/nn/module.py +55 -21
- lucid/types.py +58 -0
- {lucid_dl-2.11.3.dist-info → lucid_dl-2.11.4.dist-info}/METADATA +21 -5
- {lucid_dl-2.11.3.dist-info → lucid_dl-2.11.4.dist-info}/RECORD +7 -7
- {lucid_dl-2.11.3.dist-info → lucid_dl-2.11.4.dist-info}/WHEEL +0 -0
- {lucid_dl-2.11.3.dist-info → lucid_dl-2.11.4.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.3.dist-info → lucid_dl-2.11.4.dist-info}/top_level.txt +0 -0
lucid/nn/module.py
CHANGED
|
@@ -13,7 +13,22 @@ from typing import (
|
|
|
13
13
|
from collections import OrderedDict
|
|
14
14
|
|
|
15
15
|
from lucid._tensor import Tensor
|
|
16
|
-
from lucid.types import
|
|
16
|
+
from lucid.types import (
|
|
17
|
+
_ArrayOrScalar,
|
|
18
|
+
_BackwardHook,
|
|
19
|
+
_DeviceType,
|
|
20
|
+
_ForwardHook,
|
|
21
|
+
_ForwardHookKwargs,
|
|
22
|
+
_ForwardPreHook,
|
|
23
|
+
_ForwardPreHookKwargs,
|
|
24
|
+
_FullBackwardHook,
|
|
25
|
+
_FullBackwardPreHook,
|
|
26
|
+
_LoadStateDictPostHook,
|
|
27
|
+
_LoadStateDictPreHook,
|
|
28
|
+
_NumPyArray,
|
|
29
|
+
_StateDictHook,
|
|
30
|
+
_StateDictPreHook,
|
|
31
|
+
)
|
|
17
32
|
|
|
18
33
|
import lucid.nn as nn
|
|
19
34
|
|
|
@@ -30,26 +45,6 @@ __all__ = [
|
|
|
30
45
|
]
|
|
31
46
|
|
|
32
47
|
|
|
33
|
-
_ForwardPreHook = Callable[["Module", tuple[Any, ...]], tuple[Any, ...] | None]
|
|
34
|
-
_ForwardPreHookKwargs = Callable[
|
|
35
|
-
["Module", tuple[Any, ...], dict[str, Any]],
|
|
36
|
-
tuple[tuple[Any, ...], dict[str, Any]] | None,
|
|
37
|
-
]
|
|
38
|
-
_ForwardHook = Callable[["Module", tuple[Any, ...], Any], Any | None]
|
|
39
|
-
_ForwardHookKwargs = Callable[
|
|
40
|
-
["Module", tuple[Any, ...], dict[str, Any], Any], Any | None
|
|
41
|
-
]
|
|
42
|
-
|
|
43
|
-
_BackwardHook = Callable[[Tensor, _NumPyArray], None]
|
|
44
|
-
_FullBackwardPreHook = Callable[
|
|
45
|
-
["Module", tuple[_NumPyArray | None, ...]], tuple[_NumPyArray | None, ...] | None
|
|
46
|
-
]
|
|
47
|
-
_FullBackwardHook = Callable[
|
|
48
|
-
["Module", tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
|
|
49
|
-
tuple[_NumPyArray | None, ...] | None,
|
|
50
|
-
]
|
|
51
|
-
|
|
52
|
-
|
|
53
48
|
class Module:
|
|
54
49
|
_registry_map: dict[Type, OrderedDict[str, Any]] = {}
|
|
55
50
|
_alt_name: str = ""
|
|
@@ -70,10 +65,17 @@ class Module:
|
|
|
70
65
|
tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
|
|
71
66
|
] = []
|
|
72
67
|
self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
|
|
68
|
+
|
|
73
69
|
self._backward_hooks: list[_BackwardHook] = []
|
|
74
70
|
self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
|
|
75
71
|
self._full_backward_hooks: list[_FullBackwardHook] = []
|
|
76
72
|
|
|
73
|
+
self._state_dict_pre_hooks: list[_StateDictPreHook] = []
|
|
74
|
+
self._state_dict_hooks: list[_StateDictHook] = []
|
|
75
|
+
|
|
76
|
+
self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
|
|
77
|
+
self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []
|
|
78
|
+
|
|
77
79
|
self._state_dict_pass_attr = set()
|
|
78
80
|
|
|
79
81
|
def __setattr__(self, name: str, value: Any) -> None:
|
|
@@ -155,6 +157,26 @@ class Module:
|
|
|
155
157
|
self._full_backward_hooks.append(hook)
|
|
156
158
|
return lambda: self._full_backward_hooks.remove(hook)
|
|
157
159
|
|
|
160
|
+
def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
|
|
161
|
+
self._state_dict_pre_hooks.append(hook)
|
|
162
|
+
return lambda: self._state_dict_pre_hooks.remove(hook)
|
|
163
|
+
|
|
164
|
+
def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
|
|
165
|
+
self._state_dict_hooks.append(hook)
|
|
166
|
+
return lambda: self._state_dict_hooks.remove(hook)
|
|
167
|
+
|
|
168
|
+
def register_load_state_dict_pre_hook(
|
|
169
|
+
self, hook: _LoadStateDictPreHook
|
|
170
|
+
) -> Callable:
|
|
171
|
+
self._load_state_dict_pre_hooks.append(hook)
|
|
172
|
+
return lambda: self._load_state_dict_pre_hooks.remove(hook)
|
|
173
|
+
|
|
174
|
+
def register_load_state_dict_post_hook(
|
|
175
|
+
self, hook: _LoadStateDictPostHook
|
|
176
|
+
) -> Callable:
|
|
177
|
+
self._load_state_dict_post_hooks.append(hook)
|
|
178
|
+
return lambda: self._load_state_dict_post_hooks.remove(hook)
|
|
179
|
+
|
|
158
180
|
def reset_parameters(self) -> None:
|
|
159
181
|
for param in self.parameters():
|
|
160
182
|
param.zero()
|
|
@@ -231,6 +253,9 @@ class Module:
|
|
|
231
253
|
prefix: str = "",
|
|
232
254
|
keep_vars: bool = False,
|
|
233
255
|
) -> OrderedDict:
|
|
256
|
+
for hook in self._state_dict_pre_hooks:
|
|
257
|
+
hook(self, prefix, keep_vars)
|
|
258
|
+
|
|
234
259
|
if destination is None:
|
|
235
260
|
destination = OrderedDict()
|
|
236
261
|
|
|
@@ -249,9 +274,15 @@ class Module:
|
|
|
249
274
|
if key in self._state_dict_pass_attr:
|
|
250
275
|
del destination[key]
|
|
251
276
|
|
|
277
|
+
for hook in self._state_dict_hooks:
|
|
278
|
+
hook(self, destination, prefix, keep_vars)
|
|
279
|
+
|
|
252
280
|
return destination
|
|
253
281
|
|
|
254
282
|
def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
|
|
283
|
+
for hook in self._load_state_dict_pre_hooks:
|
|
284
|
+
hook(self, state_dict, strict)
|
|
285
|
+
|
|
255
286
|
own_state = self.state_dict(keep_vars=True)
|
|
256
287
|
|
|
257
288
|
missing_keys = set(own_state.keys()) - set(state_dict.keys())
|
|
@@ -277,6 +308,9 @@ class Module:
|
|
|
277
308
|
elif strict:
|
|
278
309
|
raise KeyError(f"Unexpected key '{key}' in state_dict.")
|
|
279
310
|
|
|
311
|
+
for hook in self._load_state_dict_post_hooks:
|
|
312
|
+
hook(self, missing_keys, unexpected_keys, strict)
|
|
313
|
+
|
|
280
314
|
def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
|
|
281
315
|
for hook, with_kwargs in self._forward_pre_hooks:
|
|
282
316
|
if with_kwargs:
|
lucid/types.py
CHANGED
|
@@ -6,8 +6,10 @@ from typing import (
|
|
|
6
6
|
Sequence,
|
|
7
7
|
Literal,
|
|
8
8
|
TypeAlias,
|
|
9
|
+
TYPE_CHECKING,
|
|
9
10
|
runtime_checkable,
|
|
10
11
|
)
|
|
12
|
+
from collections import OrderedDict
|
|
11
13
|
import re
|
|
12
14
|
|
|
13
15
|
import numpy as np
|
|
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
|
|
|
76
78
|
) -> None: ...
|
|
77
79
|
|
|
78
80
|
|
|
81
|
+
@runtime_checkable
|
|
82
|
+
class _ModuleHookable(Protocol):
|
|
83
|
+
def register_forward_pre_hook(
|
|
84
|
+
self, hook: Callable, *, with_kwargs: bool = False
|
|
85
|
+
) -> Callable: ...
|
|
86
|
+
|
|
87
|
+
def register_forward_hook(
|
|
88
|
+
self, hook: Callable, *, with_kwargs: bool = False
|
|
89
|
+
) -> Callable: ...
|
|
90
|
+
|
|
91
|
+
def register_backward_hook(self, hook: Callable) -> Callable: ...
|
|
92
|
+
|
|
93
|
+
def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
|
|
94
|
+
|
|
95
|
+
def register_full_backward_hook(self, hook: Callable) -> Callable: ...
|
|
96
|
+
|
|
97
|
+
def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
|
|
98
|
+
|
|
99
|
+
def register_state_dict_hook(self, hook: Callable) -> Callable: ...
|
|
100
|
+
|
|
101
|
+
def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
|
|
102
|
+
|
|
103
|
+
def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
_ForwardPreHook: TypeAlias = Callable[
|
|
107
|
+
[_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
|
|
108
|
+
]
|
|
109
|
+
_ForwardPreHookKwargs: TypeAlias = Callable[
|
|
110
|
+
[_ModuleHookable, tuple[Any, ...], dict[str, Any]],
|
|
111
|
+
tuple[tuple[Any, ...], dict[str, Any]] | None,
|
|
112
|
+
]
|
|
113
|
+
_ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
|
|
114
|
+
_ForwardHookKwargs: TypeAlias = Callable[
|
|
115
|
+
[_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
|
|
116
|
+
]
|
|
117
|
+
|
|
118
|
+
_BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
|
|
119
|
+
_FullBackwardPreHook: TypeAlias = Callable[
|
|
120
|
+
[_ModuleHookable, tuple[_NumPyArray | None, ...]],
|
|
121
|
+
tuple[_NumPyArray | None, ...] | None,
|
|
122
|
+
]
|
|
123
|
+
_FullBackwardHook: TypeAlias = Callable[
|
|
124
|
+
[_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
|
|
125
|
+
tuple[_NumPyArray | None, ...] | None,
|
|
126
|
+
]
|
|
127
|
+
|
|
128
|
+
_StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
|
|
129
|
+
_StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
|
|
130
|
+
|
|
131
|
+
_LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
|
|
132
|
+
_LoadStateDictPostHook: TypeAlias = Callable[
|
|
133
|
+
[_ModuleHookable, set[str], set[str], bool], None
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
|
|
79
137
|
class Numeric:
|
|
80
138
|
def __init__(
|
|
81
139
|
self, base_dtype: type[int | float | complex], bits: int | None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: lucid-dl
|
|
3
|
-
Version: 2.11.3
|
|
3
|
+
Version: 2.11.4
|
|
4
4
|
Summary: Lumerico's Comprehensive Interface for Deep Learning
|
|
5
5
|
Home-page: https://github.com/ChanLumerico/lucid
|
|
6
6
|
Author: ChanLumerico
|
|
@@ -52,15 +52,31 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
|
|
|
52
52
|
|
|
53
53
|
```python
|
|
54
54
|
def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
55
|
-
|
|
55
|
+
```
|
|
56
|
+
```python
|
|
56
57
|
def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
|
|
57
|
-
|
|
58
|
+
```
|
|
59
|
+
```python
|
|
58
60
|
def register_backward_hook(self, hook: Callable)
|
|
59
|
-
|
|
61
|
+
```
|
|
62
|
+
```python
|
|
60
63
|
def register_full_backward_pre_hook(self, hook: Callable)
|
|
61
|
-
|
|
64
|
+
```
|
|
65
|
+
```python
|
|
62
66
|
def register_full_backward_hook(self, hook: Callable)
|
|
63
67
|
```
|
|
68
|
+
```python
|
|
69
|
+
def register_state_dict_pre_hook(self, hook: Callable)
|
|
70
|
+
```
|
|
71
|
+
```python
|
|
72
|
+
def register_state_dict_hook(self, hook: Callable)
|
|
73
|
+
```
|
|
74
|
+
```python
|
|
75
|
+
def register_load_state_dict_pre_hook(self, hook: Callable)
|
|
76
|
+
```
|
|
77
|
+
```python
|
|
78
|
+
def register_load_state_dict_post_hook(self, hook: Callable)
|
|
79
|
+
```
|
|
64
80
|
|
|
65
81
|
## 🔧 How to Install
|
|
66
82
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
lucid/__init__.py,sha256=EwNZkKALNS54RQw_D_1BtqbBo3kFMi4uSveSs6i-bdM,9161
|
|
2
2
|
lucid/error.py,sha256=qnTiVuZm3c5-DIt-OOyobZ7RUm7E1K4NR0j998LG1ug,709
|
|
3
3
|
lucid/port.py,sha256=Kt1YaSWef_eKF4KRj-UFhirvFC5urEESfYQ_BSlBZGE,3811
|
|
4
|
-
lucid/types.py,sha256=
|
|
4
|
+
lucid/types.py,sha256=Zdz2r4ledouEG-6Gi6yEza5vSLyyTzZJn7AcRKbxy8o,6906
|
|
5
5
|
lucid/_backend/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
|
|
6
6
|
lucid/_backend/core.py,sha256=neZF9uQlwNp-yHeyi0IbKlN556O-lsqThsIFExsf8-Y,11791
|
|
7
7
|
lucid/_backend/metal.py,sha256=vQegTENuPjeAM_EIXfuOpnIXZBeMbGVxCFvm4s-_NNo,4215
|
|
@@ -78,7 +78,7 @@ lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG299
|
|
|
78
78
|
lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
|
|
79
79
|
lucid/nn/__init__.py,sha256=_hk6KltQIJuWXowXstMSu3TjiaTP8zMLNvGpjnA9Mpw,182
|
|
80
80
|
lucid/nn/fused.py,sha256=75fcXuo6fHSO-JtjuKhowhHSDr4qc5871WR63sUzH0g,5492
|
|
81
|
-
lucid/nn/module.py,sha256=
|
|
81
|
+
lucid/nn/module.py,sha256=_EWtGkAuWWCPZ5f3t5pJOOzpi14gQBpP7JW2S8o4_GE,26855
|
|
82
82
|
lucid/nn/parameter.py,sha256=NQS65YKn2B59wZbZIoT1mpDsU_F08y3yLi7hmV1B6yo,1232
|
|
83
83
|
lucid/nn/util.py,sha256=Yw1iBSPrGV_r_F51qpqLYdafNE_hyaA0DPWYP-rjaig,1699
|
|
84
84
|
lucid/nn/_kernel/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
|
|
@@ -135,8 +135,8 @@ lucid/visual/graph.py,sha256=ZSlrJI3dQwYjz8XbgAfNd8-8YuH9Ji7Mz1J6UsnHTaI,4711
|
|
|
135
135
|
lucid/visual/mermaid.py,sha256=87hFe4l9EYP6Cg2l2hP2INQiBHKkgVClH5nBWFY9ddY,26499
|
|
136
136
|
lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
|
|
137
137
|
lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
|
|
138
|
-
lucid_dl-2.11.3.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
|
|
139
|
-
lucid_dl-2.11.
|
|
140
|
-
lucid_dl-2.11.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
141
|
-
lucid_dl-2.11.3.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
|
|
142
|
-
lucid_dl-2.11.3.dist-info/RECORD,,
|
|
138
|
+
lucid_dl-2.11.4.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
|
|
139
|
+
lucid_dl-2.11.4.dist-info/METADATA,sha256=F8r0MrpLAlRuT0IK0RFLQCurlxcj3gUYLK2-tyKhAOI,12273
|
|
140
|
+
lucid_dl-2.11.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
141
|
+
lucid_dl-2.11.4.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
|
|
142
|
+
lucid_dl-2.11.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|