lucid-dl 2.11.2__py3-none-any.whl → 2.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/models/imgclf/crossvit.py +1 -1
- lucid/models/imgclf/efficientformer.py +2 -2
- lucid/models/imgclf/maxvit.py +1 -1
- lucid/models/imgclf/pvt.py +2 -2
- lucid/models/imggen/vae.py +1 -1
- lucid/models/objdet/efficientdet.py +24 -8
- lucid/models/objdet/rcnn.py +1 -1
- lucid/models/objdet/util.py +5 -0
- lucid/nn/module.py +142 -13
- lucid/types.py +58 -0
- lucid/visual/__init__.py +1 -0
- lucid/visual/graph.py +3 -0
- lucid/visual/mermaid.py +818 -0
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/METADATA +30 -21
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/RECORD +18 -17
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/WHEEL +1 -1
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.2.dist-info → lucid_dl-2.11.4.dist-info}/top_level.txt +0 -0
lucid/models/imgclf/crossvit.py
CHANGED
lucid/models/imgclf/efficientformer.py
CHANGED
@@ -80,7 +80,7 @@ class _Attention(nn.Module):
         y, x = lucid.meshgrid(
             lucid.arange(resolution[0]), lucid.arange(resolution[1]), indexing="ij"
         )
-        pos = lucid.stack([y, x]).flatten(
+        pos = lucid.stack([y, x]).flatten(start_axis=1)
         rel_pos = lucid.abs(pos[..., :, None] - pos[..., None, :])
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

@@ -159,7 +159,7 @@ class _Downsample(nn.Module):

 class _Flatten(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        x = x.flatten(
+        x = x.flatten(start_axis=2).swapaxes(1, 2)
         return x

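These edits make the flatten axis explicit. As a point of reference, here is a minimal NumPy sketch of the assumed semantics of flatten(start_axis=k), namely keep the leading axes and collapse everything from k onward, which is what .flatten(start_axis=2).swapaxes(1, 2) relies on to turn a (B, C, H, W) feature map into (B, H*W, C) tokens; shapes and values are illustrative only, not taken from lucid's test suite.

import numpy as np

# Assumed semantics of Tensor.flatten(start_axis=k): axes before k are kept,
# axes from k onward are collapsed into one.
x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)    # stand-in for a (B, C, H, W) map
flat = x.reshape(*x.shape[:2], -1)                   # ~ x.flatten(start_axis=2) -> (2, 3, 16)
tokens = np.swapaxes(flat, 1, 2)                     # ~ .swapaxes(1, 2)         -> (2, 16, 3)
print(flat.shape, tokens.shape)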
lucid/models/imgclf/maxvit.py
CHANGED
@@ -216,7 +216,7 @@ def _grid_reverse(

 def _get_relative_position_index(win_h: int, win_w: int) -> Tensor:
     coords = lucid.stack(lucid.meshgrid(lucid.arange(win_h), lucid.arange(win_w)))
-    coords_flatten = lucid.flatten(coords,
+    coords_flatten = lucid.flatten(coords, start_axis=1)

     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
     relative_coords = relative_coords.transpose((1, 2, 0))
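For context, _get_relative_position_index builds the pairwise offset table used for relative position bias; the patched line only makes the flatten axis explicit. Below is a NumPy illustration of the same arithmetic for a 2x2 window; the indexing="ij" choice is an assumption about lucid's meshgrid default, not something shown in this hunk.

import numpy as np

win_h = win_w = 2
ys, xs = np.meshgrid(np.arange(win_h), np.arange(win_w), indexing="ij")
coords = np.stack([ys, xs])                            # (2, win_h, win_w)
coords_flatten = coords.reshape(2, -1)                 # ~ lucid.flatten(coords, start_axis=1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.transpose(1, 2, 0)   # (N, N, 2) pairwise (dy, dx) offsets
print(relative_coords.shape)                           # (4, 4, 2)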
lucid/models/imgclf/pvt.py
CHANGED
@@ -328,7 +328,7 @@ class _DWConv(nn.Module):
         B, _, C = x.shape
         x = x.swapaxes(1, 2).reshape(B, C, H, W)
         x = self.dwconv(x)
-        x = x.flatten(
+        x = x.flatten(start_axis=2).swapaxes(1, 2)

         return x

@@ -548,7 +548,7 @@ class _OverlapPatchEmbed(nn.Module):
     def forward(self, x: Tensor) -> tuple[Tensor, int, int]:
         x = self.proj(x)
         H, W = x.shape[2:]
-        x = x.flatten(
+        x = x.flatten(start_axis=2).swapaxes(1, 2)
         x = self.norm(x)

         return x, H, W
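Both PVT changes follow the same pattern: _DWConv reshapes the token sequence back to a spatial map for the depthwise convolution and then returns to tokens, and _OverlapPatchEmbed flattens the projected map into tokens. A NumPy stand-in for that shape round trip, with the convolution itself elided and all sizes purely illustrative:

import numpy as np

B, C, H, W = 2, 8, 7, 7
tokens = np.zeros((B, H * W, C))                        # (B, N, C) token sequence
fmap = np.swapaxes(tokens, 1, 2).reshape(B, C, H, W)    # back to (B, C, H, W) for dwconv
# ... depthwise convolution would run on `fmap` here ...
tokens_again = fmap.reshape(B, C, -1).swapaxes(1, 2)    # ~ flatten(start_axis=2).swapaxes(1, 2)
print(tokens_again.shape)                               # (2, 49, 8)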
lucid/models/imggen/vae.py
CHANGED
lucid/models/objdet/efficientdet.py
CHANGED
@@ -82,23 +82,37 @@ class _BiFPN(nn.Module):
     def _norm_weight(self, weight: Tensor) -> Tensor:
         return weight / (weight.sum(axis=0) + self.eps)

+    @staticmethod
+    def _resize_like(x: Tensor, ref: Tensor) -> Tensor:
+        if x.shape[2:] == ref.shape[2:]:
+            return x
+        return F.interpolate(x, size=ref.shape[2:], mode="nearest")
+
     def _forward_up(self, feats: tuple[Tensor]) -> tuple[Tensor]:
         p3_in, p4_in, p5_in, p6_in, p7_in = feats

         w1_p6_up = self._norm_weight(self.acts["6_w1"](self.weights["6_w1"]))
-        p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self.
+        p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self._resize_like(
+            self.ups["6"](p7_in), p6_in
+        )
         p6_up = self.convs["6_up"](p6_up_in)

         w1_p5_up = self._norm_weight(self.acts["5_w1"](self.weights["5_w1"]))
-        p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self.
+        p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self._resize_like(
+            self.ups["5"](p6_up), p5_in
+        )
         p5_up = self.convs["5_up"](p5_up_in)

         w1_p4_up = self._norm_weight(self.acts["4_w1"](self.weights["4_w1"]))
-        p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self.
+        p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self._resize_like(
+            self.ups["4"](p5_up), p4_in
+        )
         p4_up = self.convs["4_up"](p4_up_in)

         w1_p3_up = self._norm_weight(self.acts["3_w1"](self.weights["3_w1"]))
-        p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self.
+        p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self._resize_like(
+            self.ups["3"](p4_up), p3_in
+        )
         p3_out = self.convs["3_up"](p3_up_in)

         return p3_out, p4_up, p5_up, p6_up
@@ -113,7 +127,7 @@ class _BiFPN(nn.Module):
         p4_down_in = (
             w2_p4_down[0] * p4_in
             + w2_p4_down[1] * p4_up
-            + w2_p4_down[2] * self.downs["4"](p3_out)
+            + w2_p4_down[2] * self._resize_like(self.downs["4"](p3_out), p4_in)
         )
         p4_out = self.convs["4_down"](p4_down_in)

@@ -121,7 +135,7 @@ class _BiFPN(nn.Module):
         p5_down_in = (
             w2_p5_down[0] * p5_in
             + w2_p5_down[1] * p5_up
-            + w2_p5_down[2] * self.downs["5"](p4_out)
+            + w2_p5_down[2] * self._resize_like(self.downs["5"](p4_out), p5_in)
         )
         p5_out = self.convs["5_down"](p5_down_in)

@@ -129,12 +143,14 @@ class _BiFPN(nn.Module):
         p6_down_in = (
             w2_p6_down[0] * p6_in
             + w2_p6_down[1] * p6_up
-            + w2_p6_down[2] * self.downs["6"](p5_out)
+            + w2_p6_down[2] * self._resize_like(self.downs["6"](p5_out), p6_in)
         )
         p6_out = self.convs["6_down"](p6_down_in)

         w2_p7_down = self._norm_weight(self.acts["7_w2"](self.weights["7_w2"]))
-        p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self.
+        p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self._resize_like(
+            self.downs["7"](p6_out), p7_in
+        )
         p7_out = self.convs["7_down"](p7_down_in)

         return p3_out, p4_out, p5_out, p6_out, p7_out
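The new _resize_like helper exists because repeated stride-2 reductions with rounding do not invert cleanly: doubling a coarser pyramid level does not always reproduce the finer level's spatial size, so the fused maps can be off by one pixel. F.interpolate(x, size=ref.shape[2:], mode="nearest") then snaps the mismatched map onto the reference grid. A rough size calculation, assuming ceil-mode stride-2 reductions and a 600-pixel input; the numbers are illustrative, not taken from the lucid implementation:

import math

size = 600
levels = {f"P{level}": math.ceil(size / 2**level) for level in range(3, 8)}
print(levels)                                  # {'P3': 75, 'P4': 38, 'P5': 19, 'P6': 10, 'P7': 5}
print(levels["P6"] * 2, "vs", levels["P5"])    # 20 vs 19 -> the upsampled map must be
                                               # resized to P5's grid before the weighted sum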
lucid/models/objdet/rcnn.py
CHANGED
lucid/models/objdet/util.py
CHANGED
@@ -283,6 +283,11 @@ class SelectiveSearch(nn.Module):


 def iou(boxes_a: Tensor, boxes_b: Tensor) -> Tensor:
+    if boxes_a.ndim == 1:
+        boxes_a = boxes_a.unsqueeze(0)
+    if boxes_b.ndim == 1:
+        boxes_b = boxes_b.unsqueeze(0)
+
     x1a, y1a, x2a, y2a = boxes_a.unbind(axis=1)
     x1b, y1b, x2b, y2b = boxes_b.unbind(axis=1)

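The added guard lets iou accept a single box of shape (4,) as well as a batch of shape (N, 4) by promoting 1-D inputs before the pairwise broadcasting. A NumPy mirror of that broadcasting pattern; only the shape handling is taken from the diff, and the exact overlap/area conventions of lucid's iou are not shown, so this is an illustration rather than the library's implementation:

import numpy as np

def iou_np(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
    # Promote single boxes (4,) to (1, 4) so pairwise broadcasting always applies,
    # mirroring the guard added to lucid's iou().
    if boxes_a.ndim == 1:
        boxes_a = boxes_a[None, :]
    if boxes_b.ndim == 1:
        boxes_b = boxes_b[None, :]

    x1a, y1a, x2a, y2a = boxes_a.T
    x1b, y1b, x2b, y2b = boxes_b.T

    # Pairwise intersection between every box in a and every box in b.
    iw = np.clip(np.minimum(x2a[:, None], x2b) - np.maximum(x1a[:, None], x1b), 0, None)
    ih = np.clip(np.minimum(y2a[:, None], y2b) - np.maximum(y1a[:, None], y1b), 0, None)
    inter = iw * ih
    area_a = (x2a - x1a) * (y2a - y1a)
    area_b = (x2b - x1b) * (y2b - y1b)
    return inter / (area_a[:, None] + area_b - inter)

print(iou_np(np.array([0, 0, 2, 2]), np.array([[1, 1, 3, 3]])))  # 1x1 matrix, ~0.143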
lucid/nn/module.py
CHANGED
@@ -13,7 +13,22 @@ from typing import (
 from collections import OrderedDict

 from lucid._tensor import Tensor
-from lucid.types import
+from lucid.types import (
+    _ArrayOrScalar,
+    _BackwardHook,
+    _DeviceType,
+    _ForwardHook,
+    _ForwardHookKwargs,
+    _ForwardPreHook,
+    _ForwardPreHookKwargs,
+    _FullBackwardHook,
+    _FullBackwardPreHook,
+    _LoadStateDictPostHook,
+    _LoadStateDictPreHook,
+    _NumPyArray,
+    _StateDictHook,
+    _StateDictPreHook,
+)

 import lucid.nn as nn

@@ -29,9 +44,6 @@ __all__ = [
     "set_state_dict_pass_attr",
 ]

-_ForwardHookType = Callable[["Module", tuple[Tensor], tuple[Tensor]], None]
-_BackwardHookType = Callable[[Tensor, _NumPyArray], None]
-

 class Module:
     _registry_map: dict[Type, OrderedDict[str, Any]] = {}
@@ -49,8 +61,20 @@ class Module:
         self.training = True
         self.device: _DeviceType = "cpu"

-        self.
-
+        self._forward_pre_hooks: list[
+            tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
+        ] = []
+        self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
+
+        self._backward_hooks: list[_BackwardHook] = []
+        self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
+        self._full_backward_hooks: list[_FullBackwardHook] = []
+
+        self._state_dict_pre_hooks: list[_StateDictPreHook] = []
+        self._state_dict_hooks: list[_StateDictHook] = []
+
+        self._load_state_dict_pre_hooks: list[_LoadStateDictPreHook] = []
+        self._load_state_dict_post_hooks: list[_LoadStateDictPostHook] = []

         self._state_dict_pass_attr = set()

@@ -106,14 +130,53 @@ class Module:

         self.__setattr__(name, buffer)

-    def
-        self
-
-
-
+    def register_forward_pre_hook(
+        self,
+        hook: _ForwardPreHook | _ForwardPreHookKwargs,
+        *,
+        with_kwargs: bool = False,
+    ) -> Callable:
+        self._forward_pre_hooks.append((hook, with_kwargs))
+        return lambda: self._forward_pre_hooks.remove((hook, with_kwargs))
+
+    def register_forward_hook(
+        self, hook: _ForwardHook | _ForwardHookKwargs, *, with_kwargs: bool = False
+    ) -> Callable:
+        self._forward_hooks.append((hook, with_kwargs))
+        return lambda: self._forward_hooks.remove((hook, with_kwargs))
+
+    def register_backward_hook(self, hook: _BackwardHook) -> Callable:
         self._backward_hooks.append(hook)
         return lambda: self._backward_hooks.remove(hook)

+    def register_full_backward_pre_hook(self, hook: _FullBackwardPreHook) -> Callable:
+        self._full_backward_pre_hooks.append(hook)
+        return lambda: self._full_backward_pre_hooks.remove(hook)
+
+    def register_full_backward_hook(self, hook: _FullBackwardHook) -> Callable:
+        self._full_backward_hooks.append(hook)
+        return lambda: self._full_backward_hooks.remove(hook)
+
+    def register_state_dict_pre_hook(self, hook: _StateDictPreHook) -> Callable:
+        self._state_dict_pre_hooks.append(hook)
+        return lambda: self._state_dict_pre_hooks.remove(hook)
+
+    def register_state_dict_hook(self, hook: _StateDictHook) -> Callable:
+        self._state_dict_hooks.append(hook)
+        return lambda: self._state_dict_hooks.remove(hook)
+
+    def register_load_state_dict_pre_hook(
+        self, hook: _LoadStateDictPreHook
+    ) -> Callable:
+        self._load_state_dict_pre_hooks.append(hook)
+        return lambda: self._load_state_dict_pre_hooks.remove(hook)
+
+    def register_load_state_dict_post_hook(
+        self, hook: _LoadStateDictPostHook
+    ) -> Callable:
+        self._load_state_dict_post_hooks.append(hook)
+        return lambda: self._load_state_dict_post_hooks.remove(hook)
+
     def reset_parameters(self) -> None:
         for param in self.parameters():
             param.zero()
@@ -190,6 +253,9 @@ class Module:
         prefix: str = "",
         keep_vars: bool = False,
     ) -> OrderedDict:
+        for hook in self._state_dict_pre_hooks:
+            hook(self, prefix, keep_vars)
+
         if destination is None:
             destination = OrderedDict()

@@ -208,9 +274,15 @@ class Module:
             if key in self._state_dict_pass_attr:
                 del destination[key]

+        for hook in self._state_dict_hooks:
+            hook(self, destination, prefix, keep_vars)
+
         return destination

     def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
+        for hook in self._load_state_dict_pre_hooks:
+            hook(self, state_dict, strict)
+
         own_state = self.state_dict(keep_vars=True)

         missing_keys = set(own_state.keys()) - set(state_dict.keys())
@@ -236,15 +308,72 @@ class Module:
         elif strict:
             raise KeyError(f"Unexpected key '{key}' in state_dict.")

+        for hook in self._load_state_dict_post_hooks:
+            hook(self, missing_keys, unexpected_keys, strict)
+
     def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
+        for hook, with_kwargs in self._forward_pre_hooks:
+            if with_kwargs:
+                result = hook(self, args, kwargs)
+                if result is not None:
+                    args, kwargs = result
+            else:
+                result = hook(self, args)
+                if result is not None:
+                    args = result
+
         output = self.forward(*args, **kwargs)
-
-
+
+        for hook, with_kwargs in self._forward_hooks:
+            if with_kwargs:
+                result = hook(self, args, kwargs, output)
+            else:
+                result = hook(self, args, output)
+            if result is not None:
+                output = result

         if isinstance(output, Tensor) and self._backward_hooks:
             for hook in self._backward_hooks:
                 output.register_hook(hook)

+        if self._full_backward_pre_hooks or self._full_backward_hooks:
+            outputs = output if isinstance(output, tuple) else (output,)
+            output_tensors = [out for out in outputs if isinstance(out, Tensor)]
+
+            if output_tensors:
+                grad_outputs: list[_NumPyArray | None] = [None] * len(output_tensors)
+                called = False
+
+                def _call_full_backward_hooks() -> None:
+                    nonlocal called, grad_outputs
+                    if called:
+                        return
+                    called = True
+
+                    grad_output_tuple = tuple(grad_outputs)
+                    for hook in self._full_backward_pre_hooks:
+                        result = hook(self, grad_output_tuple)
+                        if result is not None:
+                            grad_output_tuple = result
+
+                    grad_input_tuple = tuple(
+                        arg.grad if isinstance(arg, Tensor) else None for arg in args
+                    )
+                    for hook in self._full_backward_hooks:
+                        hook(self, grad_input_tuple, grad_output_tuple)
+
+                for idx, out in enumerate(output_tensors):
+
+                    def _make_hook(index: int) -> Callable:
+                        def _hook(_, grad: _NumPyArray) -> None:
+                            grad_outputs[index] = grad
+                            if all(g is not None for g in grad_outputs):
+                                _call_full_backward_hooks()
+
+                        return _hook
+
+                    out.register_hook(_make_hook(idx))
+
         return output

     def __repr__(self) -> str:
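Taken together, Module now exposes a PyTorch-style hook surface: forward pre-hooks can rewrite the call arguments, forward hooks can replace the output, and the full-backward machinery fires once every output tensor has received its gradient. A minimal usage sketch based on the registration methods and call signatures added above; the toy Identity module, lucid.arange, and .reshape are assumptions pieced together from other files in this diff rather than a documented example.

import lucid
import lucid.nn as nn

class Identity(nn.Module):
    def forward(self, x):
        return x

def log_shapes(module, args, output):
    # Plain forward hook: called as hook(module, args, output); returning None keeps the output.
    print(type(module).__name__, [a.shape for a in args], output.shape)

m = Identity()
remove = m.register_forward_hook(log_shapes)   # returns a callable that unregisters the hook
m(lucid.arange(6).reshape(2, 3))               # hook prints the input/output shapes
remove()                                       # detach the hook again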
lucid/types.py
CHANGED
@@ -6,8 +6,10 @@ from typing import (
     Sequence,
     Literal,
     TypeAlias,
+    TYPE_CHECKING,
     runtime_checkable,
 )
+from collections import OrderedDict
 import re

 import numpy as np
@@ -76,6 +78,62 @@ class _TensorLike(Protocol):
     ) -> None: ...


+@runtime_checkable
+class _ModuleHookable(Protocol):
+    def register_forward_pre_hook(
+        self, hook: Callable, *, with_kwargs: bool = False
+    ) -> Callable: ...
+
+    def register_forward_hook(
+        self, hook: Callable, *, with_kwargs: bool = False
+    ) -> Callable: ...
+
+    def register_backward_hook(self, hook: Callable) -> Callable: ...
+
+    def register_full_backward_pre_hook(self, hook: Callable) -> Callable: ...
+
+    def register_full_backward_hook(self, hook: Callable) -> Callable: ...
+
+    def register_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
+
+    def register_state_dict_hook(self, hook: Callable) -> Callable: ...
+
+    def register_load_state_dict_pre_hook(self, hook: Callable) -> Callable: ...
+
+    def register_load_state_dict_post_hook(self, hook: Callable) -> Callable: ...
+
+
+_ForwardPreHook: TypeAlias = Callable[
+    [_ModuleHookable, tuple[Any, ...]], tuple[Any, ...] | None
+]
+_ForwardPreHookKwargs: TypeAlias = Callable[
+    [_ModuleHookable, tuple[Any, ...], dict[str, Any]],
+    tuple[tuple[Any, ...], dict[str, Any]] | None,
+]
+_ForwardHook: TypeAlias = Callable[[_ModuleHookable, tuple[Any, ...], Any], Any | None]
+_ForwardHookKwargs: TypeAlias = Callable[
+    [_ModuleHookable, tuple[Any, ...], dict[str, Any], Any], Any | None
+]
+
+_BackwardHook: TypeAlias = Callable[[_TensorLike, _NumPyArray], None]
+_FullBackwardPreHook: TypeAlias = Callable[
+    [_ModuleHookable, tuple[_NumPyArray | None, ...]],
+    tuple[_NumPyArray | None, ...] | None,
+]
+_FullBackwardHook: TypeAlias = Callable[
+    [_ModuleHookable, tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
+    tuple[_NumPyArray | None, ...] | None,
+]
+
+_StateDictPreHook: TypeAlias = Callable[[_ModuleHookable, str, bool], None]
+_StateDictHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, str, bool], None]
+
+_LoadStateDictPreHook: TypeAlias = Callable[[_ModuleHookable, OrderedDict, bool], None]
+_LoadStateDictPostHook: TypeAlias = Callable[
+    [_ModuleHookable, set[str], set[str], bool], None
+]
+
+
 class Numeric:
     def __init__(
         self, base_dtype: type[int | float | complex], bits: int | None
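The new aliases mainly serve static checkers; a hook is just a callable with the matching positional signature. A small sketch of annotating user hooks with these aliases, with entirely hypothetical hook bodies:

from lucid.types import _ForwardHook, _StateDictHook

def replace_output(module, args, output):
    # _ForwardHook: (module, args, output) -> replacement output, or None to keep it.
    return None

def strip_private_keys(module, destination, prefix, keep_vars):
    # _StateDictHook: may edit `destination` in place; the declared return type is None.
    for key in [k for k in destination if k.startswith("_")]:
        del destination[key]

forward_hook: _ForwardHook = replace_output
state_dict_hook: _StateDictHook = strip_private_keys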
lucid/visual/__init__.py
CHANGED
lucid/visual/graph.py
CHANGED
@@ -1,4 +1,6 @@
 from typing import Union
+from warnings import deprecated
+
 import networkx as nx
 import matplotlib.pyplot as plt

@@ -9,6 +11,7 @@ from lucid._tensor import Tensor
 __all__ = ["draw_tensor_graph"]


+@deprecated("This feature will be re-written with Mermaid in future relases.")
 def draw_tensor_graph(
     tensor: Tensor,
     horizontal: bool = False,
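Note that warnings.deprecated is the PEP 702 decorator and is only available in the standard library from Python 3.13 (older interpreters would need typing_extensions.deprecated), so this import effectively ties lucid.visual.graph to a recent runtime. A standalone sketch of how the decorator behaves at call time, using a hypothetical function rather than lucid code:

import warnings
from warnings import deprecated  # Python 3.13+ (PEP 702)

@deprecated("use the Mermaid-based renderer instead")
def old_draw() -> str:
    return "drawn"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_draw()                       # emits a DeprecationWarning with the message above

print(caught[0].category.__name__)   # DeprecationWarning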