lucid-dl 2.11.2__py3-none-any.whl → 2.11.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,7 +79,7 @@ class _PatchEmbed(nn.Module):
                  f"Input image size {(H, W)} does not match with {self.img_size}."
              )

-         x = self.proj(x).flatten(axis=2).swapaxes(1, 2)
+         x = self.proj(x).flatten(start_axis=2).swapaxes(1, 2)
          return x

@@ -80,7 +80,7 @@ class _Attention(nn.Module):
          y, x = lucid.meshgrid(
              lucid.arange(resolution[0]), lucid.arange(resolution[1]), indexing="ij"
          )
-         pos = lucid.stack([y, x]).flatten(axis=1)
+         pos = lucid.stack([y, x]).flatten(start_axis=1)
          rel_pos = lucid.abs(pos[..., :, None] - pos[..., None, :])
          rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

@@ -159,7 +159,7 @@ class _Downsample(nn.Module):

  class _Flatten(nn.Module):
      def forward(self, x: Tensor) -> Tensor:
-         x = x.flatten(axis=2).swapaxes(1, 2)
+         x = x.flatten(start_axis=2).swapaxes(1, 2)
          return x

@@ -216,7 +216,7 @@ def _grid_reverse(

  def _get_relative_position_index(win_h: int, win_w: int) -> Tensor:
      coords = lucid.stack(lucid.meshgrid(lucid.arange(win_h), lucid.arange(win_w)))
-     coords_flatten = lucid.flatten(coords, axis=1)
+     coords_flatten = lucid.flatten(coords, start_axis=1)

      relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
      relative_coords = relative_coords.transpose((1, 2, 0))
@@ -328,7 +328,7 @@ class _DWConv(nn.Module):
          B, _, C = x.shape
          x = x.swapaxes(1, 2).reshape(B, C, H, W)
          x = self.dwconv(x)
-         x = x.flatten(axis=2).swapaxes(1, 2)
+         x = x.flatten(start_axis=2).swapaxes(1, 2)

          return x

@@ -548,7 +548,7 @@ class _OverlapPatchEmbed(nn.Module):
      def forward(self, x: Tensor) -> tuple[Tensor, int, int]:
          x = self.proj(x)
          H, W = x.shape[2:]
-         x = x.flatten(axis=2).swapaxes(1, 2)
+         x = x.flatten(start_axis=2).swapaxes(1, 2)
          x = self.norm(x)

          return x, H, W
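
The hunks above (and the RCNN hunk further down) are a single mechanical rename: `Tensor.flatten`'s keyword changes from `axis` to `start_axis`, i.e. "flatten every axis from this one onward." A NumPy sketch of what `x.flatten(start_axis=2).swapaxes(1, 2)` appears to compute, assuming PyTorch-like semantics; `flatten_from` is a hypothetical stand-in, not lucid's API:

```python
import numpy as np

# Hypothetical stand-in for Tensor.flatten(start_axis=...): keep the
# leading axes, merge everything from `start_axis` onward.
def flatten_from(x: np.ndarray, start_axis: int) -> np.ndarray:
    return x.reshape(*x.shape[:start_axis], -1)

x = np.zeros((2, 96, 14, 14))               # (B, C, H, W) patch features
tokens = flatten_from(x, 2).swapaxes(1, 2)  # (B, H*W, C) token layout
print(tokens.shape)                         # (2, 196, 96)
```
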
@@ -51,7 +51,7 @@ class VAE(nn.Module):
          h = x
          for encoder in self.encoders:
              h = encoder(h)
-             mu, logvar = lucid.split(h, 2, axis=1)
+             mu, logvar = lucid.chunk(h, 2, axis=1)
              z = self.reparameterize(mu, logvar)

              mus.append(mu)
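
Unlike the renames above, this is a behavioral fix: assuming lucid follows the PyTorch convention, `split(h, 2, axis=1)` yields pieces *of size 2* along axis 1, whereas `chunk(h, 2, axis=1)` yields exactly *two* pieces, which is what unpacking into `mu, logvar` requires. A NumPy sketch of the distinction:

```python
import numpy as np

h = np.zeros((4, 8))  # encoder output; axis 1 must halve into mu / logvar

# split-by-size semantics (like torch.split(h, 2, dim=1)): pieces OF SIZE 2,
# so an 8-wide tensor yields four pieces and two-way unpacking fails.
by_size = [h[:, i:i + 2] for i in range(0, h.shape[1], 2)]
print(len(by_size))  # 4

# chunk semantics (like torch.chunk(h, 2, dim=1)): exactly TWO pieces.
mu, logvar = np.array_split(h, 2, axis=1)
print(mu.shape, logvar.shape)  # (4, 4) (4, 4)
```
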
@@ -82,23 +82,37 @@ class _BiFPN(nn.Module):
      def _norm_weight(self, weight: Tensor) -> Tensor:
          return weight / (weight.sum(axis=0) + self.eps)

+     @staticmethod
+     def _resize_like(x: Tensor, ref: Tensor) -> Tensor:
+         if x.shape[2:] == ref.shape[2:]:
+             return x
+         return F.interpolate(x, size=ref.shape[2:], mode="nearest")
+
      def _forward_up(self, feats: tuple[Tensor]) -> tuple[Tensor]:
          p3_in, p4_in, p5_in, p6_in, p7_in = feats

          w1_p6_up = self._norm_weight(self.acts["6_w1"](self.weights["6_w1"]))
-         p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self.ups["6"](p7_in)
+         p6_up_in = w1_p6_up[0] * p6_in + w1_p6_up[1] * self._resize_like(
+             self.ups["6"](p7_in), p6_in
+         )
          p6_up = self.convs["6_up"](p6_up_in)

          w1_p5_up = self._norm_weight(self.acts["5_w1"](self.weights["5_w1"]))
-         p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self.ups["5"](p6_up)
+         p5_up_in = w1_p5_up[0] * p5_in + w1_p5_up[1] * self._resize_like(
+             self.ups["5"](p6_up), p5_in
+         )
          p5_up = self.convs["5_up"](p5_up_in)

          w1_p4_up = self._norm_weight(self.acts["4_w1"](self.weights["4_w1"]))
-         p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self.ups["4"](p5_up)
+         p4_up_in = w1_p4_up[0] * p4_in + w1_p4_up[1] * self._resize_like(
+             self.ups["4"](p5_up), p4_in
+         )
          p4_up = self.convs["4_up"](p4_up_in)

          w1_p3_up = self._norm_weight(self.acts["3_w1"](self.weights["3_w1"]))
-         p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self.ups["3"](p4_up)
+         p3_up_in = w1_p3_up[0] * p3_in + w1_p3_up[1] * self._resize_like(
+             self.ups["3"](p4_up), p3_in
+         )
          p3_out = self.convs["3_up"](p3_up_in)

          return p3_out, p4_up, p5_up, p6_up
@@ -113,7 +127,7 @@ class _BiFPN(nn.Module):
          p4_down_in = (
              w2_p4_down[0] * p4_in
              + w2_p4_down[1] * p4_up
-             + w2_p4_down[2] * self.downs["4"](p3_out)
+             + w2_p4_down[2] * self._resize_like(self.downs["4"](p3_out), p4_in)
          )
          p4_out = self.convs["4_down"](p4_down_in)

@@ -121,7 +135,7 @@ class _BiFPN(nn.Module):
          p5_down_in = (
              w2_p5_down[0] * p5_in
              + w2_p5_down[1] * p5_up
-             + w2_p5_down[2] * self.downs["5"](p4_out)
+             + w2_p5_down[2] * self._resize_like(self.downs["5"](p4_out), p5_in)
          )
          p5_out = self.convs["5_down"](p5_down_in)

@@ -129,12 +143,14 @@ class _BiFPN(nn.Module):
          p6_down_in = (
              w2_p6_down[0] * p6_in
              + w2_p6_down[1] * p6_up
-             + w2_p6_down[2] * self.downs["6"](p5_out)
+             + w2_p6_down[2] * self._resize_like(self.downs["6"](p5_out), p6_in)
          )
          p6_out = self.convs["6_down"](p6_down_in)

          w2_p7_down = self._norm_weight(self.acts["7_w2"](self.weights["7_w2"]))
-         p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self.downs["7"](p6_out)
+         p7_down_in = w2_p7_down[0] * p7_in + w2_p7_down[1] * self._resize_like(
+             self.downs["7"](p6_out), p7_in
+         )
          p7_out = self.convs["7_down"](p7_down_in)

          return p3_out, p4_out, p5_out, p6_out, p7_out
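
The new `_resize_like` guard makes every BiFPN fusion shape-safe: whenever an upsampled or downsampled neighbor lands a pixel off from its target level (typical with odd feature-map sizes), it is nearest-resized onto the reference before the weighted sum. A self-contained sketch of the idea; this NumPy `resize_like` is an illustrative stand-in for the `F.interpolate(..., mode="nearest")` call above:

```python
import numpy as np

def resize_like(x: np.ndarray, ref: np.ndarray) -> np.ndarray:
    # Nearest-neighbor resize of a (B, C, H, W) map to ref's spatial size.
    if x.shape[2:] == ref.shape[2:]:
        return x
    H, W = ref.shape[2:]
    rows = np.arange(H) * x.shape[2] // H
    cols = np.arange(W) * x.shape[3] // W
    return x[:, :, rows][:, :, :, cols]

p6 = np.zeros((1, 64, 7, 7))
p7_up = np.zeros((1, 64, 8, 8))  # a 4x4 level upsampled by 2: off by one vs 7x7
fused = 0.5 * p6 + 0.5 * resize_like(p7_up, p6)
print(fused.shape)               # (1, 64, 7, 7)
```
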
@@ -120,7 +120,7 @@ class RCNN(nn.Module):

          if isinstance(feats, (tuple, list)):
              feats = feats[-1]
-         feats = feats.flatten(axis=1)
+         feats = feats.flatten(start_axis=1)

          cls_scores = self.svm(feats)
          bbox_deltas = self.bbox_reg(feats)
@@ -283,6 +283,11 @@ class SelectiveSearch(nn.Module):


  def iou(boxes_a: Tensor, boxes_b: Tensor) -> Tensor:
+     if boxes_a.ndim == 1:
+         boxes_a = boxes_a.unsqueeze(0)
+     if boxes_b.ndim == 1:
+         boxes_b = boxes_b.unsqueeze(0)
+
      x1a, y1a, x2a, y2a = boxes_a.unbind(axis=1)
      x1b, y1b, x2b, y2b = boxes_b.unbind(axis=1)

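
The added `ndim` guard promotes a single box of shape `(4,)` to a batch of one, `(1, 4)`, so the subsequent `unbind(axis=1)` and pairwise broadcasting work for both call styles. A NumPy sketch of the pairwise-IoU pattern the function builds on:

```python
import numpy as np

def iou(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
    # Mirror of the new guard: accept a single (4,) box as a batch of one.
    boxes_a = np.atleast_2d(boxes_a)
    boxes_b = np.atleast_2d(boxes_b)

    x1a, y1a, x2a, y2a = boxes_a.T
    x1b, y1b, x2b, y2b = boxes_b.T

    # Broadcast (N, 1) against (M,) to form the (N, M) pairwise table.
    ix1 = np.maximum(x1a[:, None], x1b)
    iy1 = np.maximum(y1a[:, None], y1b)
    ix2 = np.minimum(x2a[:, None], x2b)
    iy2 = np.minimum(y2a[:, None], y2b)

    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area_a = (x2a - x1a) * (y2a - y1a)
    area_b = (x2b - x1b) * (y2b - y1b)
    return inter / (area_a[:, None] + area_b - inter)

print(iou(np.array([0, 0, 2, 2]), np.array([1, 1, 3, 3])))  # [[0.142857...]]
```
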
lucid/nn/module.py CHANGED
@@ -29,8 +29,25 @@ __all__ = [
      "set_state_dict_pass_attr",
  ]

- _ForwardHookType = Callable[["Module", tuple[Tensor], tuple[Tensor]], None]
- _BackwardHookType = Callable[[Tensor, _NumPyArray], None]
+
+ _ForwardPreHook = Callable[["Module", tuple[Any, ...]], tuple[Any, ...] | None]
+ _ForwardPreHookKwargs = Callable[
+     ["Module", tuple[Any, ...], dict[str, Any]],
+     tuple[tuple[Any, ...], dict[str, Any]] | None,
+ ]
+ _ForwardHook = Callable[["Module", tuple[Any, ...], Any], Any | None]
+ _ForwardHookKwargs = Callable[
+     ["Module", tuple[Any, ...], dict[str, Any], Any], Any | None
+ ]
+
+ _BackwardHook = Callable[[Tensor, _NumPyArray], None]
+ _FullBackwardPreHook = Callable[
+     ["Module", tuple[_NumPyArray | None, ...]], tuple[_NumPyArray | None, ...] | None
+ ]
+ _FullBackwardHook = Callable[
+     ["Module", tuple[_NumPyArray | None, ...], tuple[_NumPyArray | None, ...]],
+     tuple[_NumPyArray | None, ...] | None,
+ ]


  class Module:
@@ -49,8 +66,13 @@ class Module:
          self.training = True
          self.device: _DeviceType = "cpu"

-         self._forward_hooks: list[_ForwardHookType] = []
-         self._backward_hooks: list[_BackwardHookType] = []
+         self._forward_pre_hooks: list[
+             tuple[_ForwardPreHook | _ForwardPreHookKwargs, bool]
+         ] = []
+         self._forward_hooks: list[tuple[_ForwardHook | _ForwardHookKwargs, bool]] = []
+         self._backward_hooks: list[_BackwardHook] = []
+         self._full_backward_pre_hooks: list[_FullBackwardPreHook] = []
+         self._full_backward_hooks: list[_FullBackwardHook] = []

          self._state_dict_pass_attr = set()

@@ -106,14 +128,33 @@ class Module:

          self.__setattr__(name, buffer)

-     def register_forward_hook(self, hook: _ForwardHookType) -> Callable:
-         self._forward_hooks.append(hook)
-         return lambda: self._forward_hooks.remove(hook)
-
-     def register_backward_hook(self, hook: _BackwardHookType) -> Callable:
+     def register_forward_pre_hook(
+         self,
+         hook: _ForwardPreHook | _ForwardPreHookKwargs,
+         *,
+         with_kwargs: bool = False,
+     ) -> Callable:
+         self._forward_pre_hooks.append((hook, with_kwargs))
+         return lambda: self._forward_pre_hooks.remove((hook, with_kwargs))
+
+     def register_forward_hook(
+         self, hook: _ForwardHook | _ForwardHookKwargs, *, with_kwargs: bool = False
+     ) -> Callable:
+         self._forward_hooks.append((hook, with_kwargs))
+         return lambda: self._forward_hooks.remove((hook, with_kwargs))
+
+     def register_backward_hook(self, hook: _BackwardHook) -> Callable:
          self._backward_hooks.append(hook)
          return lambda: self._backward_hooks.remove(hook)

+     def register_full_backward_pre_hook(self, hook: _FullBackwardPreHook) -> Callable:
+         self._full_backward_pre_hooks.append(hook)
+         return lambda: self._full_backward_pre_hooks.remove(hook)
+
+     def register_full_backward_hook(self, hook: _FullBackwardHook) -> Callable:
+         self._full_backward_hooks.append(hook)
+         return lambda: self._full_backward_hooks.remove(hook)
+
      def reset_parameters(self) -> None:
          for param in self.parameters():
              param.zero()
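
With these additions the registration surface mirrors the familiar PyTorch one: pre-hooks may rewrite the inputs, forward hooks may replace the output, the `with_kwargs=True` variants also receive keyword arguments, and every `register_*` call returns a zero-argument remover. A hedged usage sketch against the signatures above (`nn.Linear` and `lucid.random.randn` are assumed to exist as used elsewhere in this diff):

```python
import lucid
import lucid.nn as nn

layer = nn.Linear(4, 2)  # example module; any nn.Module should do

def double_input(module, args):
    (x,) = args
    return (x * 2,)      # returning a tuple replaces the positional args

def peek_output(module, args, output):
    print("forward produced shape", output.shape)  # returning None keeps output

remove_pre = layer.register_forward_pre_hook(double_input)
remove_fwd = layer.register_forward_hook(peek_output)

y = layer(lucid.random.randn((1, 4)))

remove_pre()  # the removers detach the hooks again
remove_fwd()
```
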
@@ -237,14 +278,68 @@ class Module:
              raise KeyError(f"Unexpected key '{key}' in state_dict.")

      def __call__(self, *args: Any, **kwargs: Any) -> Tensor | tuple[Tensor, ...]:
+         for hook, with_kwargs in self._forward_pre_hooks:
+             if with_kwargs:
+                 result = hook(self, args, kwargs)
+                 if result is not None:
+                     args, kwargs = result
+             else:
+                 result = hook(self, args)
+                 if result is not None:
+                     args = result
+
          output = self.forward(*args, **kwargs)
-         for hook in self._forward_hooks:
-             hook(self, args, output)
+
+         for hook, with_kwargs in self._forward_hooks:
+             if with_kwargs:
+                 result = hook(self, args, kwargs, output)
+             else:
+                 result = hook(self, args, output)
+             if result is not None:
+                 output = result

          if isinstance(output, Tensor) and self._backward_hooks:
              for hook in self._backward_hooks:
                  output.register_hook(hook)

+         if self._full_backward_pre_hooks or self._full_backward_hooks:
+             outputs = output if isinstance(output, tuple) else (output,)
+             output_tensors = [out for out in outputs if isinstance(out, Tensor)]
+
+             if output_tensors:
+                 grad_outputs: list[_NumPyArray | None] = [None] * len(output_tensors)
+                 called = False
+
+                 def _call_full_backward_hooks() -> None:
+                     nonlocal called, grad_outputs
+                     if called:
+                         return
+                     called = True
+
+                     grad_output_tuple = tuple(grad_outputs)
+                     for hook in self._full_backward_pre_hooks:
+                         result = hook(self, grad_output_tuple)
+                         if result is not None:
+                             grad_output_tuple = result
+
+                     grad_input_tuple = tuple(
+                         arg.grad if isinstance(arg, Tensor) else None for arg in args
+                     )
+                     for hook in self._full_backward_hooks:
+                         hook(self, grad_input_tuple, grad_output_tuple)
+
+                 for idx, out in enumerate(output_tensors):
+
+                     def _make_hook(index: int) -> Callable:
+                         def _hook(_, grad: _NumPyArray) -> None:
+                             grad_outputs[index] = grad
+                             if all(g is not None for g in grad_outputs):
+                                 _call_full_backward_hooks()
+
+                         return _hook
+
+                     out.register_hook(_make_hook(idx))
+
          return output

      def __repr__(self) -> str:
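
The full-backward machinery above works by attaching a tensor-level hook to every tensor output: incoming gradients are buffered in `grad_outputs`, and once all of them have arrived, the pre-hooks and hooks fire exactly once, with `grad_input` read from the `.grad` of the positional tensor arguments. A hedged sketch of how this might be used; `Tensor.backward()` and `nn.Linear` are assumed APIs not shown in this diff:

```python
import lucid
import lucid.nn as nn

layer = nn.Linear(4, 2)

def report_grads(module, grad_input, grad_output):
    # Fires once per backward pass, after gradients for every tensor
    # output have arrived; entries are raw gradient arrays or None.
    print(len(grad_input), "grad inputs;", len(grad_output), "grad outputs")

remove = layer.register_full_backward_hook(report_grads)

x = lucid.Tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True)
layer(x).sum().backward()  # assumed autograd entry point
remove()
```
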
lucid/visual/__init__.py CHANGED
@@ -1 +1,2 @@
  from .graph import *
+ from .mermaid import *
lucid/visual/graph.py CHANGED
@@ -1,4 +1,6 @@
  from typing import Union
+ from warnings import deprecated
+
  import networkx as nx
  import matplotlib.pyplot as plt

@@ -9,6 +11,7 @@ from lucid._tensor import Tensor
  __all__ = ["draw_tensor_graph"]


+ @deprecated("This feature will be re-written with Mermaid in future relases.")
  def draw_tensor_graph(
      tensor: Tensor,
      horizontal: bool = False,
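
`from warnings import deprecated` is only available on Python 3.13+ (PEP 702); on older interpreters the same decorator ships as `typing_extensions.deprecated`. A minimal sketch of the pattern, with an illustrative fallback that lucid itself may not use:

```python
try:
    from warnings import deprecated            # stdlib since Python 3.13
except ImportError:
    from typing_extensions import deprecated   # backport for older versions

@deprecated("Illustrative message; see the decorator usage above.")
def old_draw() -> None: ...
```

lucid/visual/mermaid.py ADDED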
@@ -0,0 +1,818 @@
+ from dataclasses import dataclass
+ import json
+ from typing import Iterable, Literal
+
+ import lucid
+ import lucid.nn as nn
+
+ from lucid._tensor import Tensor
+ from lucid.types import _ShapeLike
+
+
+ __all__ = ["build_mermaid_chart"]
+
+
+ _NN_MODULES_PREFIX = "lucid.nn.modules."
+
+
+ # fmt: off
+ _BUILTIN_SUBPACKAGE_STYLE: dict[str, tuple[str, str]] = {
+     "conv": ("#ffe8e8", "#c53030"),
+     "norm": ("#e6fffa", "#2c7a7b"),
+     "activation": ("#faf5ff", "#6b46c1"),
+     "linear": ("#ebf8ff", "#2b6cb0"),
+     "pool": ("#fefcbf", "#b7791f"),
+     "drop": ("#edf2f7", "#4a5568"),
+     "transformer": ("#e2e8f0", "#334155"),
+     "attention": ("#f0fff4", "#2f855a"),
+     "vision": ("#fdf2f8", "#b83280"),
+     "rnn": ("#f0f9ff", "#0284c7"),
+     "sparse": ("#f1f5f9", "#475569"),
+     "loss": ("#fffbeb", "#d97706"),
+     "einops": ("#ecfccb", "#65a30d"),
+ }
+ # fmt: on
+
+
+ @dataclass
+ class _ModuleNode:
+     module: nn.Module
+     name: str
+     depth: int
+     children: list["_ModuleNode"]
+     group: list[nn.Module] | None = None
+
+     @property
+     def count(self) -> int:
+         return 1 if self.group is None else len(self.group)
+
+     def iter_modules(self) -> Iterable[nn.Module]:
+         if self.group is None:
+             yield self.module
+             return
+         yield from self.group
+
+
+ def _flatten_tensors(obj: object) -> list[Tensor]:
+     tensors: list[Tensor] = []
+
+     if isinstance(obj, Tensor):
+         tensors.append(obj)
+     elif isinstance(obj, (list, tuple)):
+         for item in obj:
+             tensors.extend(_flatten_tensors(item))
+     elif isinstance(obj, dict):
+         for item in obj.values():
+             tensors.extend(_flatten_tensors(item))
+
+     return tensors
+
+
+ def _build_tree(
+     module: nn.Module,
+     depth: int,
+     max_depth: int,
+     name: str = "",
+     *,
+     collapse_repeats: bool = False,
+     repeat_min: int = 3,
+     hide_subpackages: set[str] | None = None,
+     hide_module_names: set[str] | None = None,
+ ) -> _ModuleNode:
+     children: list[_ModuleNode] = []
+     if depth < max_depth:
+         for child_name, child in module._modules.items():
+             path = f"{name}.{child_name}" if name else child_name
+
+             node = _build_tree(
+                 child,
+                 depth + 1,
+                 max_depth,
+                 path,
+                 collapse_repeats=collapse_repeats,
+                 repeat_min=repeat_min,
+                 hide_subpackages=hide_subpackages,
+                 hide_module_names=hide_module_names,
+             )
+
+             child_cls_name = type(child).__name__
+             child_mod_path = type(child).__module__
+             child_subpkg = None
+
+             if child_mod_path.startswith(_NN_MODULES_PREFIX):
+                 rest = child_mod_path[len(_NN_MODULES_PREFIX) :]
+                 child_subpkg = rest.split(".", 1)[0] if rest else None
+
+             excluded = False
+             if hide_module_names and child_cls_name in hide_module_names:
+                 excluded = True
+             if hide_subpackages and child_subpkg and child_subpkg in hide_subpackages:
+                 excluded = True
+
+             if excluded:
+                 children.extend(node.children)
+             else:
+                 children.append(node)
+
+     if collapse_repeats and children:
+         children = _collapse_repeated_children(children, repeat_min=repeat_min)
+
+     return _ModuleNode(module=module, name=name, depth=depth, children=children)
+
+
+ def _module_label(module: nn.Module, show_params: bool) -> str:
+     class_name = getattr(module, "_alt_name", "") or type(module).__name__
+     if show_params:
+         return f"{class_name} ({module.parameter_size:,} params)"
+
+     return class_name
+
+
+ def _builtin_subpackage_key(module: nn.Module) -> str | None:
+     mod_path = type(module).__module__
+     if not mod_path.startswith(_NN_MODULES_PREFIX):
+         return None
+
+     rest = mod_path[len(_NN_MODULES_PREFIX) :]
+     return rest.split(".", 1)[0] if rest else None
+
+
+ def _shape_text_color(module: nn.Module) -> str | None:
+     subpkg = _builtin_subpackage_key(module)
+     if subpkg is None:
+         return None
+
+     style = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg)
+     if style is None:
+         return None
+
+     _, stroke = style
+     return stroke
+
+
+ def _parse_rgba(value: str) -> tuple[str, float] | None:
+     v = value.strip().lower()
+     if not v.startswith("rgba(") or not v.endswith(")"):
+         return None
+
+     inner = v[5:-1]
+     parts = [p.strip() for p in inner.split(",")]
+     if len(parts) != 4:
+         return None
+
+     try:
+         r, g, b = (int(float(x)) for x in parts[:3])
+         a = float(parts[3])
+     except Exception:
+         return None
+
+     r = max(0, min(255, r))
+     g = max(0, min(255, g))
+     b = max(0, min(255, b))
+     a = max(0.0, min(1.0, a))
+
+     return (f"#{r:02x}{g:02x}{b:02x}", a)
+
+
+ def _container_attr_label(node: _ModuleNode) -> str | None:
+     if not isinstance(node.module, (nn.Sequential, nn.ModuleList, nn.ModuleDict)):
+         return None
+     if not node.name:
+         return None
+
+     leaf = node.name.rsplit(".", 1)[-1]
+     if leaf.isdigit():
+         return None
+
+     return leaf
+
+
+ def _escape_label(text: str) -> str:
+     return text.replace('"', "&quot;")
+
+
+ def _shape_str(shape: object) -> str:
+     try:
+         if isinstance(shape, tuple):
+             return "(" + ",".join(str(x) for x in shape) + ")"
+     except Exception:
+         pass
+     return str(shape)
+
+
+ def _flatten_shapes(obj: object) -> list[tuple[int, ...]]:
+     shapes: list[tuple[int, ...]] = []
+     for t in _flatten_tensors(obj):
+         try:
+             shapes.append(tuple(int(x) for x in t.shape))
+         except Exception:
+             continue
+     return shapes
+
+
+ def _shapes_brief(shapes: list[tuple[int, ...]]) -> str:
+     if not shapes:
+         return "?"
+     if len(shapes) == 1:
+         return _shape_str(shapes[0])
+     return f"{_shape_str(shapes[0])}x{len(shapes)}"
+
+
+ def _node_signature(node: _ModuleNode) -> tuple:
+     return (type(node.module), tuple(_node_signature(c) for c in node.children))
+
+
+ def _collapse_repeated_children(
+     children: list[_ModuleNode], *, repeat_min: int
+ ) -> list[_ModuleNode]:
+     if repeat_min <= 1:
+         return children
+
+     out: list[_ModuleNode] = []
+     i = 0
+     while i < len(children):
+         base = children[i]
+         base_sig = _node_signature(base)
+         j = i + 1
+         while j < len(children) and _node_signature(children[j]) == base_sig:
+             j += 1
+
+         run = children[i:j]
+         if len(run) >= repeat_min:
+             out.append(
+                 _ModuleNode(
+                     module=base.module,
+                     name=base.name,
+                     depth=base.depth,
+                     children=base.children,
+                     group=[n.module for n in run],
+                 )
+             )
+         else:
+             out.extend(run)
+
+         i = j
+     return out
+
+
+ def build_mermaid_chart(
+     module: nn.Module,
+     input_shape: _ShapeLike | list[_ShapeLike] | None = None,
+     inputs: Iterable[Tensor] | Tensor | None = None,
+     depth: int = 2,
+     direction: str = "LR",
+     include_io: bool = True,
+     show_params: bool = False,
+     return_lines: bool = False,
+     copy_to_clipboard: bool = False,
+     compact: bool = False,
+     use_class_defs: bool = False,
+     end_semicolons: bool = True,
+     edge_mode: Literal["dataflow", "execution"] = "execution",
+     collapse_repeats: bool = True,
+     repeat_min: int = 2,
+     color_by_subpackage: bool = True,
+     container_name_from_attr: bool = True,
+     edge_stroke_width: float = 2.0,
+     emphasize_model_title: bool = True,
+     model_title_font_px: int = 20,
+     show_shapes: bool = False,
+     hide_subpackages: Iterable[str] = (),
+     hide_module_names: Iterable[str] = (),
+     dash_multi_input_edges: bool = True,
+     subgraph_fill: str = "#000000",
+     subgraph_fill_opacity: float = 0.05,
+     subgraph_stroke: str = "#000000",
+     subgraph_stroke_opacity: float = 0.75,
+     force_text_color: str | None = None,
+     edge_curve: str = "natural",
+     node_spacing: int = 50,
+     rank_spacing: int = 50,
+     **forward_kwargs,
+ ) -> str | list[str]:
+     if inputs is None and input_shape is None:
+         raise ValueError("Either inputs or input_shape must be provided.")
+     if depth < 0:
+         raise ValueError("depth must be >= 0")
+
+     tree = _build_tree(
+         module,
+         depth=0,
+         max_depth=depth,
+         collapse_repeats=collapse_repeats,
+         repeat_min=repeat_min,
+         hide_subpackages=set(hide_subpackages),
+         hide_module_names=set(hide_module_names),
+     )
+
+     nodes: list[_ModuleNode] = []
+
+     def _collect(n: _ModuleNode) -> None:
+         nodes.append(n)
+         for c in n.children:
+             _collect(c)
+
+     _collect(tree)
+
+     module_to_node: dict[nn.Module, _ModuleNode] = {n.module: n for n in nodes}
+
+     module_to_id: dict[nn.Module, str] = {}
+     for idx, n in enumerate(nodes):
+         node_id = f"m{idx}"
+         for mod in n.iter_modules():
+             module_to_id[mod] = node_id
+
+     def _build_parent_map(root: nn.Module) -> dict[nn.Module, nn.Module]:
+         parent: dict[nn.Module, nn.Module] = {}
+
+         def _walk(mod: nn.Module) -> None:
+             for child in mod._modules.values():
+                 parent[child] = mod
+                 _walk(child)
+
+         _walk(root)
+         return parent
+
+     parent_map = _build_parent_map(module)
+     has_non_root_included = any(mod is not module for mod in module_to_id)
+
+     def _map_to_included(mod: nn.Module) -> nn.Module | None:
+         cur = mod
+         while cur not in module_to_id and cur in parent_map:
+             cur = parent_map[cur]
+         return cur if cur in module_to_id else None
+
+     hooks = []
+     edges: set[tuple[str, str]] = set()
+     tensor_producer: dict[int, nn.Module] = {}
+     module_in_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
+     module_out_shapes: dict[nn.Module, list[tuple[int, ...]]] = {}
+     input_node_id = "input"
+     output_node_id = "output"
+     root_module = module
+     exec_order: list[nn.Module] = []
+
+     def _hook(
+         mod: nn.Module, input_arg: tuple, output: Tensor | tuple[Tensor, ...]
+     ) -> None:
+         mapped_mod = _map_to_included(mod)
+         if mapped_mod is None:
+             return
+
+         if show_shapes:
+             module_in_shapes[mapped_mod] = _flatten_shapes(input_arg)
+             module_out_shapes[mapped_mod] = _flatten_shapes(output)
+
+         if edge_mode == "dataflow":
+             input_tensors = _flatten_tensors(input_arg)
+             for t in input_tensors:
+                 producer = tensor_producer.get(id(t))
+                 if producer is None:
+                     if include_io and (
+                         mapped_mod is not root_module or not has_non_root_included
+                     ):
+                         edges.add((input_node_id, module_to_id[mapped_mod]))
+
+                 else:
+                     if producer is not mapped_mod:
+                         edges.add((module_to_id[producer], module_to_id[mapped_mod]))
+
+             output_tensors = _flatten_tensors(output)
+             for t in output_tensors:
+                 key = id(t)
+                 if mapped_mod is root_module and key in tensor_producer:
+                     continue
+                 tensor_producer[key] = mapped_mod
+
+         exec_order.append(mapped_mod)
+
+     for mod in module.modules():
+         hooks.append(mod.register_forward_hook(_hook))
+
+     try:
+         if inputs is None:
+             if isinstance(input_shape, list):
+                 input_tensors = [
+                     lucid.random.randn(shape, device=module.device)
+                     for shape in input_shape
+                 ]
+             else:
+                 input_tensors = [lucid.random.randn(input_shape, device=module.device)]
+         else:
+             if isinstance(inputs, Tensor):
+                 input_tensors = [inputs]
+             else:
+                 input_tensors = list(inputs)
+
+         outputs = module(*input_tensors, **forward_kwargs)
+     finally:
+         for remove in hooks:
+             remove()
+
+     model_input_shapes = _flatten_shapes(input_tensors)
+     model_output_shapes = _flatten_shapes(outputs)
+
+     if edge_mode == "execution":
+         seq = []
+         for m in exec_order:
+             if not seq or seq[-1] is not m:
+                 seq.append(m)
+         seq_no_root = [m for m in seq if m is not root_module]
+         if not seq_no_root:
+             seq_no_root = [root_module]
+
+         def _is_included_container(mod: nn.Module) -> bool:
+             node = module_to_node.get(mod)
+             return bool(node and node.children)
+
+         seq_effective = [m for m in seq_no_root if not _is_included_container(m)]
+         if not seq_effective:
+             seq_effective = seq_no_root
+
+         for prev, cur in zip(seq_effective, seq_effective[1:]):
+             edges.add((module_to_id[prev], module_to_id[cur]))
+
+         if include_io and seq_effective:
+             edges.add((input_node_id, module_to_id[seq_effective[0]]))
+             edges.add((module_to_id[seq_effective[-1]], output_node_id))
+
+     else:
+         if include_io:
+             output_tensors = _flatten_tensors(outputs)
+             for t in output_tensors:
+                 producer = tensor_producer.get(id(t))
+                 if producer is not None:
+                     edges.add((module_to_id[producer], output_node_id))
+
+     node_ids = set(module_to_id.values())
+     nodes_with_edges: set[str] = set()
+
+     for src, dst in edges:
+         if src in node_ids:
+             nodes_with_edges.add(src)
+         if dst in node_ids:
+             nodes_with_edges.add(dst)
+
+     container_node_ids = {module_to_id[n.module] for n in nodes if n.children}
+     container_with_edges = container_node_ids & nodes_with_edges
+     extra_edges: list[tuple[str, str]] = []
+
+     def _endpoint_in(node_id: str) -> str:
+         return f"{node_id}_in" if node_id in container_with_edges else node_id
+
+     def _endpoint_out(node_id: str) -> str:
+         return f"{node_id}_out" if node_id in container_with_edges else node_id
+
+     def _first_leaf_id(n: _ModuleNode) -> str:
+         cur = n
+         while cur.children:
+             cur = cur.children[0]
+         return module_to_id[cur.module]
+
+     def _last_leaf_id(n: _ModuleNode) -> str:
+         cur = n
+         while cur.children:
+             cur = cur.children[-1]
+         return module_to_id[cur.module]
+
+     lines: list[str] = []
+     init_cfg: dict[str, object] = {
+         "flowchart": {
+             "curve": "step" if edge_curve == "round" else edge_curve,
+             "nodeSpacing": node_spacing,
+             "rankSpacing": rank_spacing,
+         }
+     }
+     css_parts: list[str] = []
+     if edge_curve == "round":
+         css_parts.append(
+             ".edgePath path { stroke-linecap: round; stroke-linejoin: round; }"
+         )
+     if force_text_color:
+         css_parts.append(
+             f".nodeLabel, .edgeLabel, .cluster text, .node text "
+             f"{{ fill: {force_text_color} !important; }} "
+             f".node foreignObject *, .cluster foreignObject * "
+             f"{{ color: {force_text_color} !important; }}"
+         )
+     if css_parts:
+         init_cfg["themeCSS"] = " ".join(css_parts)
+
+     lines.append(f"%%{{init: {json.dumps(init_cfg, separators=(',', ':'))} }}%%")
+     lines.append(f"flowchart {direction}")
+
+     if edge_stroke_width and edge_stroke_width != 1.0:
+         lines.append(f" linkStyle default stroke-width:{edge_stroke_width}px")
+
+     if use_class_defs:
+         lines.append(" classDef module fill:#f9f9f9,stroke:#333,stroke-width:1px;")
+         lines.append(" classDef modelio fill:#fff3cd,stroke:#a67c00,stroke-width:1px;")
+         lines.append(
+             " classDef internalio fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
+         )
+         lines.append(
+             " classDef anchor fill:transparent,stroke:transparent,color:transparent;"
+         )
+         lines.append(" classDef repeat fill:#e8f1ff,stroke:#2b6cb0,stroke-width:1px;")
+
+     vertical_levels: set[int] = set()
+     if depth >= 3:
+         vertical_levels = {depth - 2, depth - 1}
+
+     def _render(n: _ModuleNode, indent: str = " ") -> None:
+         node_id = module_to_id[n.module]
+         base_label = _module_label(n.module, show_params)
+
+         if container_name_from_attr and n.children:
+             attr_label = _container_attr_label(n)
+             if attr_label is not None:
+                 base_label = attr_label
+
+         base_label = _escape_label(base_label)
+         label_text = base_label if n.count == 1 else f"{base_label} x {n.count}"
+
+         if show_shapes and not n.children:
+             in_shapes = module_in_shapes.get(n.module, [])
+             out_shapes = module_out_shapes.get(n.module, [])
+
+             if in_shapes != out_shapes and (in_shapes or out_shapes):
+                 ins = _shapes_brief(in_shapes)
+                 outs = _shapes_brief(out_shapes)
+                 color_css = ""
+                 if not force_text_color:
+                     color = _shape_text_color(n.module)
+                     color_css = f"color:{color};" if color else ""
+                 label_text = (
+                     f"{label_text}<br/>"
+                     f"<span style='font-size:11px;{color_css}font-weight:400'>"
+                     f"{ins} \u2192 {outs}"
+                     f"</span>"
+                 )
+
+         label = label_text
+         if (
+             emphasize_model_title
+             and n.module is root_module
+             and model_title_font_px
+             and model_title_font_px > 0
+         ):
+             label = (
+                 f"<span style='font-size:{model_title_font_px}px;font-weight:700'>"
+                 f"{label}"
+                 f"</span>"
+             )
+
+         if n.children:
+             lines.append(f'{indent}subgraph sg_{node_id}["{label}"]')
+             if n.depth in vertical_levels:
+                 lines.append(f"{indent} direction TB")
+             if subgraph_fill or subgraph_stroke:
+                 parts: list[str] = []
+
+                 fill = subgraph_fill
+                 fill_opacity = subgraph_fill_opacity
+                 parsed = _parse_rgba(fill)
+
+                 if parsed is not None:
+                     fill, fill_opacity = parsed
+                 if fill:
+                     parts.append(f"fill:{fill}")
+                     parts.append(f"fill-opacity:{fill_opacity}")
+
+                 stroke = subgraph_stroke
+                 stroke_opacity = subgraph_stroke_opacity
+                 if stroke:
+                     parsed = _parse_rgba(stroke)
+                     if parsed is not None:
+                         stroke, stroke_opacity = parsed
+
+                     parts.append(f"stroke:{stroke}")
+                     parts.append(f"stroke-opacity:{stroke_opacity}")
+                     parts.append("stroke-width:1px")
+
+                 lines.append(f'{indent}style sg_{node_id} {",".join(parts)}')
+
+             if node_id in nodes_with_edges:
+                 if node_id in container_with_edges:
+                     in_id = f"{node_id}_in"
+                     out_id = f"{node_id}_out"
+
+                     if use_class_defs:
+                         lines.append(f'{indent} {in_id}(["Input"]):::internalio')
+                         lines.append(f'{indent} {out_id}(["Output"]):::internalio')
+                     else:
+                         lines.append(f'{indent} {in_id}(["Input"])')
+                         lines.append(f'{indent} {out_id}(["Output"])')
+                         lines.append(
+                             f" style {in_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
+                         )
+                         lines.append(
+                             f" style {out_id} fill:#e2e8f0,stroke:#64748b,stroke-width:1px;"
+                         )
+
+                     extra_edges.append((in_id, _endpoint_in(_first_leaf_id(n))))
+                     extra_edges.append((_endpoint_out(_last_leaf_id(n)), out_id))
+
+                 else:
+                     if use_class_defs:
+                         lines.append(f'{indent} {node_id}[""]:::anchor')
+                     else:
+                         lines.append(f'{indent} {node_id}["\u200b"]')
+                         lines.append(
+                             f" style {node_id} fill:transparent,stroke:transparent,color:transparent;"
+                         )
+
+             for c in n.children:
+                 _render(c, indent + " ")
+             lines.append(f"{indent}end")
+
+         else:
+             if use_class_defs:
+                 class_name = "module" if n.count == 1 else "repeat"
+                 lines.append(f'{indent}{node_id}["{label}"]:::{class_name}')
+             else:
+                 if n.count > 1:
+                     lines.append(f'{indent}{node_id}(["{label}"])')
+                 else:
+                     lines.append(f'{indent}{node_id}["{label}"]')
+
+     _render(tree)
+
+     if include_io:
+         input_label = "Input"
+         output_label = "Output"
+         if show_shapes:
+             in_s = _shapes_brief(model_input_shapes)
+             out_s = _shapes_brief(model_output_shapes)
+             io_color = force_text_color or "#a67c00"
+             input_label = (
+                 f"{input_label}<br/>"
+                 f"<span style='font-size:11px;color:{io_color};font-weight:400'>{in_s}</span>"
+             )
+             output_label = (
+                 f"{output_label}<br/>"
+                 f"<span style='font-size:11px;color:{io_color};font-weight:400'>{out_s}</span>"
+             )
+         if use_class_defs:
+             lines.append(f' {input_node_id}["{input_label}"]:::modelio')
+             lines.append(f' {output_node_id}["{output_label}"]:::modelio')
+         else:
+             lines.append(f' {input_node_id}["{input_label}"]')
+             lines.append(f' {output_node_id}["{output_label}"]')
+             lines.append(
+                 f" style {input_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
+             )
+             lines.append(
+                 f" style {output_node_id} fill:#fff3cd,stroke:#a67c00,stroke-width:1px;"
+             )
+
+     if color_by_subpackage:
+         for n in nodes:
+             if n.children:
+                 continue
+
+             subpkg = _builtin_subpackage_key(n.module)
+             if subpkg is None:
+                 continue
+             style = _BUILTIN_SUBPACKAGE_STYLE.get(subpkg)
+             if style is None:
+                 continue
+
+             fill, stroke = style
+             node_id = module_to_id[n.module]
+             if node_id in {input_node_id, output_node_id}:
+                 continue
+
+             if node_id.endswith("_in") or node_id.endswith("_out"):
+                 continue
+             lines.append(
+                 f" style {node_id} fill:{fill},stroke:{stroke},stroke-width:1px;"
+             )
+
+     render_edges: set[tuple[str, str]] = set()
+     for src, dst in edges:
+         src_id = _endpoint_out(src)
+         dst_id = _endpoint_in(dst)
+         if src_id != dst_id:
+             render_edges.add((src_id, dst_id))
+
+     for src, dst in extra_edges:
+         if src != dst:
+             render_edges.add((src, dst))
+
+     indegree: dict[str, int] = {}
+     for _, dst in render_edges:
+         indegree[dst] = indegree.get(dst, 0) + 1
+
+     for src, dst in sorted(render_edges):
+         arrow = "-.->" if dash_multi_input_edges and indegree.get(dst, 0) > 1 else "-->"
+         lines.append(f" {src} {arrow} {dst}")
+
+     def _finalize_lines(src_lines: list[str]) -> list[str]:
+         if not end_semicolons:
+             return src_lines
+
+         out: list[str] = []
+         for line in src_lines:
+             stripped = line.rstrip()
+             head = stripped.lstrip()
+             if (
+                 not head
+                 or head.startswith("flowchart ")
+                 or head.startswith("linkStyle ")
+                 or head.startswith("subgraph ")
+                 or head == "end"
+                 or head.startswith("classDef ")
+                 or head.startswith("class ")
+                 or head.startswith("style ")
+                 or head.startswith("%%")
+             ):
+                 out.append(line)
+                 continue
+
+             if stripped.endswith(";"):
+                 out.append(line)
+             else:
+                 out.append(f"{line};")
+
+         return out
+
+     final_lines = _finalize_lines(lines)
+
+     if compact:
+         text = " ".join(final_lines)
+     else:
+         text = "\n".join(final_lines)
+     if copy_to_clipboard:
+         _copy_to_clipboard(text)
+
+     if return_lines:
+         return final_lines
+     return text
+
+
+ def _copy_to_clipboard(text: str) -> None:
+     import os
+     import shutil
+     import subprocess
+     import sys
+
+     errors: list[str] = []
+
+     def _try(cmd: list[str]) -> bool:
+         try:
+             subprocess.run(cmd, input=text.encode("utf-8"), check=True)
+             return True
+         except Exception as e:
+             errors.append(f"{cmd!r}: {type(e).__name__}: {e}")
+             return False
+
+     if sys.platform == "darwin":
+         if shutil.which("pbcopy") and _try(["pbcopy"]):
+             return
+
+     elif sys.platform.startswith("win"):
+         if shutil.which("clip") and _try(["clip"]):
+             return
+         if shutil.which("powershell"):
+             try:
+                 subprocess.run(
+                     ["powershell", "-NoProfile", "-Command", "Set-Clipboard"],
+                     input=text.encode("utf-8"),
+                     check=True,
+                 )
+                 return
+             except Exception as e:
+                 errors.append(f"powershell Set-Clipboard: {type(e).__name__}: {e}")
+
+     else:
+         if os.environ.get("WAYLAND_DISPLAY") and shutil.which("wl-copy"):
+             if _try(["wl-copy"]):
+                 return
+         if shutil.which("xclip"):
+             if _try(["xclip", "-selection", "clipboard"]):
+                 return
+         if shutil.which("xsel"):
+             if _try(["xsel", "--clipboard", "--input"]):
+                 return
+
+     try:
+         import tkinter
+
+         root = tkinter.Tk()
+         root.withdraw()
+         root.clipboard_clear()
+         root.clipboard_append(text)
+         root.update()
+         root.destroy()
+         return
+
+     except Exception as e:
+         errors.append(f"tkinter: {type(e).__name__}: {e}")
+
+     detail = "\n".join(f"- {msg}" for msg in errors) if errors else "- (no details)"
+     raise RuntimeError(
+         "Failed to copy to clipboard. Install a clipboard utility (macOS: pbcopy; "
+         "Wayland: wl-copy; X11: xclip/xsel; Windows: clip/powershell) or enable tkinter.\n"
+         f"{detail}"
+     )
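
`build_mermaid_chart` traces a dummy forward pass under `register_forward_hook` on every submodule, then serializes the module tree and edges as Mermaid flowchart text. A hedged usage sketch; `TinyNet`, `nn.ReLU`, and `nn.Linear` are illustrative assumptions rather than code from this diff:

```python
import lucid.nn as nn
from lucid.visual import build_mermaid_chart

class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# Runs a random (1, 8) input through the model and returns Mermaid
# flowchart text, ready to paste into any Mermaid renderer.
print(build_mermaid_chart(TinyNet(), input_shape=(1, 8), depth=2, show_shapes=True))
```
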
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lucid-dl
- Version: 2.11.2
+ Version: 2.11.3
  Summary: Lumerico's Comprehensive Interface for Deep Learning
  Home-page: https://github.com/ChanLumerico/lucid
  Author: ChanLumerico
@@ -48,26 +48,19 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  ### 🔥 What's New

- - Added various inplace tensor operations (e.g. `a.add_(b)`, `a.mul_(b)`)
+ - Added additional `nn.Module` hooks for richer introspection during training:

- - Added **Noise Conditional Score Network(NCSN)** to `lucid.models.NCSN`
+   ```python
+   def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)

- - Branched a Stand-Alone Autograd Engine as `lucid.autograd`
+   def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)

- - Provides a generalized API of computing gradients:
-
-   ```python
-   import lucid.autograd as autograd
-   x = lucid.Tensor([1., 2.], requires_grad=True)
-   y = (x ** 2).sum()
-   autograd.grad(y, x)  # ∂y/∂x
-   ```
+   def register_backward_hook(self, hook: Callable)

- - Introduced **Backward Fusion** for CPU execution:
-   - Automatically fuses selected operation patterns during backpropagation to reduce graph overhead
-   - Supports identity/unary fusion (e.g. `log∘exp`, double negation, and view-like ops such as reshape/squeeze)
-   - Uses heuristic thresholds to avoid fusion overhead on small tensors
-   - Disabled by default on GPU paths to ensure stable performance
+   def register_full_backward_pre_hook(self, hook: Callable)
+
+   def register_full_backward_hook(self, hook: Callable)
+   ```

  ## 🔧 How to Install

@@ -35,19 +35,19 @@ lucid/models/imgclf/__init__.py,sha256=kQH-nNu8_TPJ7Av151WSpcY4GJ06gGAd6Ozs3m3KM
  lucid/models/imgclf/alex.py,sha256=fZsPdCjWUseCrxBwKj-i5fPSDYLgBpfm0SJe07YKRuE,1472
  lucid/models/imgclf/coatnet.py,sha256=HKjpy-lBKgz743EijT7jEeMxYjrZHzgU5fOrgtZfxYg,13720
  lucid/models/imgclf/convnext.py,sha256=mCZwauNw9Czf_fRNWAmxEwqshc_p1TYBj-L_XFBankk,8046
- lucid/models/imgclf/crossvit.py,sha256=CJcNqUeVzchgOrrNqMeqNRiPZc3uBLPT1GwCxM-n60E,21299
+ lucid/models/imgclf/crossvit.py,sha256=KKSC-Og5f5XmvBRcGC_hMUq6JZ-VdOIBWGxzilb_NIE,21305
  lucid/models/imgclf/cspnet.py,sha256=dTVwrYKECSqIfaGtYBO3TgcHYm1XyV6E2cXYvEZiuHQ,12618
  lucid/models/imgclf/cvt.py,sha256=oLGS-srIac1nHl8O6BFewcTM47DAJeUCHd2hRj_RXJ4,15567
  lucid/models/imgclf/dense.py,sha256=_fuw5Hpxm_9EKWz6IMVuGyp5gU6U2jfPfCDqByI_GRE,4431
  lucid/models/imgclf/efficient.py,sha256=X2WvLUOATT6bsA35LDPALq-FqNe1WCBJLl17r8xfx4w,13954
- lucid/models/imgclf/efficientformer.py,sha256=bX-hLl0YYjTeDc25yOtlQAwaHeVTFFr4TwJkW6l0w14,15184
+ lucid/models/imgclf/efficientformer.py,sha256=K--25NspAixmSQQ6s6NPCY8envMXpSJ6Ajuj5A7nclY,15196
  lucid/models/imgclf/inception.py,sha256=vpTRvfedp0DWHChqBkLH_7jaS4-jqdpBTwC6EzhhM88,21239
  lucid/models/imgclf/inception_next.py,sha256=W68yrOAAtKnt9_bMt9zj4Fa5PrgfK930Qg6igSML7os,8800
  lucid/models/imgclf/inception_res.py,sha256=O8L9-EPXbAbOB6Ujp2Vhv_k_9xOWVDqiYScRxbFqJZg,8341
  lucid/models/imgclf/lenet.py,sha256=55g3M8-OkJ9Se0QP66dTGzPINaq6J3OCfQNu9NqX9NI,2135
- lucid/models/imgclf/maxvit.py,sha256=lZS9LBEIbM9D_x9Xf6EgJ_YAE-9DaZThDpOiK_fuhaY,17069
+ lucid/models/imgclf/maxvit.py,sha256=SFq0IRBZBpHYPFr4LArDWz5qdGCkfrNqFCRjir_9jaY,17075
  lucid/models/imgclf/mobile.py,sha256=T52hZv2KfLGTp2e3mgxW-okmCFFIXP8EITcG5XQwhWc,34916
- lucid/models/imgclf/pvt.py,sha256=-rjswNvImFI3zjH-R1pPjNKWBLp6Mt17VXU2yKLo1jY,25985
+ lucid/models/imgclf/pvt.py,sha256=It3e9TRcgAlystccVChE6auTCrkWV1ngOy0E4BhXhDg,25997
  lucid/models/imgclf/resnest.py,sha256=c2ajAS70qstlbGTlVC8CpDGhX1GH_IlkPb-I2ZuqGUg,6970
  lucid/models/imgclf/resnet.py,sha256=Xlz1oTJVUQNZqBRrXTXboLYm9CYU_kLJqR3Xos6FsFY,10111
  lucid/models/imgclf/resnext.py,sha256=iIoo42rVy58EefHokjfGVuXc3N6x4BEUjtYOS4eczM4,1745
@@ -61,14 +61,14 @@ lucid/models/imgclf/zfnet.py,sha256=brH5tHLVWTUfCqu-BwfFb0yZV9p5DmXN4O6cyP3U26U,
  lucid/models/imggen/__init__.py,sha256=J6MlEHqXxAYINbeQmyb85ev_IEOvQDTxTQjPgX6hdpY,59
  lucid/models/imggen/ddpm.py,sha256=Nyi5bp7WMzqZ8Xl2AXuLsMmdz9QSZifbbmMPMeY1YfQ,10752
  lucid/models/imggen/ncsn.py,sha256=TF_kqnXcPZx4C_eIBqwyeRf3jp7KtiHy3SxQDnbeCj4,13555
- lucid/models/imggen/vae.py,sha256=avR8W0rzIsUyJF6veJCIEHavRsPjJsSuTDqzO2wj6sU,4084
+ lucid/models/imggen/vae.py,sha256=6KNY_mvalm7Kdu-StfaOtm7B1k_mW1XkFBmjwhaK8QU,4084
  lucid/models/objdet/__init__.py,sha256=y8_3gjSKlD8Pu0HYburcJ1FAOb4SG5eza_lg9SdFjy8,140
  lucid/models/objdet/detr.py,sha256=i-KbNoPA9nUuUVBAKRHER-NB6z_sFF5r3Y3oplRaJkI,30336
- lucid/models/objdet/efficientdet.py,sha256=o5sSB4giRdTdlDhxp8OkIQXxqLj97eMeLdmSqYlDlWA,22602
+ lucid/models/objdet/efficientdet.py,sha256=lU_dUVR9AebxPBcbDEU4uY9nWOPaDrcPThmuyx1yv2Y,23125
  lucid/models/objdet/fast_rcnn.py,sha256=d3IdI1UdN3CVhC5oPEp1cqo7Mf-d4oqaq3epm8sZgNk,5613
  lucid/models/objdet/faster_rcnn.py,sha256=wFIcQuKTpsO3H_zcrRLh5p4MdpPksV0cyC53QBekqLI,21671
- lucid/models/objdet/rcnn.py,sha256=GC0Wh1zUyZ4Bh_iw8sRmpGuoDhqzgrRU_dIYe8e9xTQ,6981
- lucid/models/objdet/util.py,sha256=_92_sPhHZBW6Mc0GFgiBk1eOghMCzWBT8UBP-B9PXek,17016
+ lucid/models/objdet/rcnn.py,sha256=IeyhE75iJtYxj99szeUvP-sYOMYGsKZcxcaBxVqqUK0,6987
+ lucid/models/objdet/util.py,sha256=KYWhx_aMiihabA4od0AXxen-EQsZMud50_b2OiJ4eqo,17147
  lucid/models/objdet/yolo/__init__.py,sha256=dBpDRsIdU6G0Q3ldBzVXdJ1yrQz0kTt2v2z-KUtVT3s,92
  lucid/models/objdet/yolo/yolo_v1.py,sha256=WEYexKTigLweN_wABq8VBBZjfzN8p2Bmawuc1Ayexbk,7160
  lucid/models/objdet/yolo/yolo_v2.py,sha256=AGl3Hlw7SCzqSVwK52m6__0MV7KmMbgvAVR6PtluOQo,12227
@@ -78,7 +78,7 @@ lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG299
  lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
  lucid/nn/__init__.py,sha256=_hk6KltQIJuWXowXstMSu3TjiaTP8zMLNvGpjnA9Mpw,182
  lucid/nn/fused.py,sha256=75fcXuo6fHSO-JtjuKhowhHSDr4qc5871WR63sUzH0g,5492
- lucid/nn/module.py,sha256=XvFWJ8NqXeZpr3RmKBQBz5eqT535Oi_7DaPN1Zi9gJc,21971
+ lucid/nn/module.py,sha256=B2esFTdKJFjqsNUsFhVx_7IE-5nyJwW4b2I8QG1iLUk,25791
  lucid/nn/parameter.py,sha256=NQS65YKn2B59wZbZIoT1mpDsU_F08y3yLi7hmV1B6yo,1232
  lucid/nn/util.py,sha256=Yw1iBSPrGV_r_F51qpqLYdafNE_hyaA0DPWYP-rjaig,1699
  lucid/nn/_kernel/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
@@ -130,12 +130,13 @@ lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232
  lucid/transforms/__init__.py,sha256=DGznMbqhXdU9FLDMKnJawScO4HCqu40Sf_j4vJGJrjc,90
  lucid/transforms/_base.py,sha256=v3elm7l0VoWvrT_qgoJiRzLH42tHoUcPIKNaPuxI_2E,1448
  lucid/transforms/image.py,sha256=S9gZzMck4EQSmDQZ3ATi2fsUh4-hqFqeDjhMMJe8TdU,3762
- lucid/visual/__init__.py,sha256=6TuFDfmXTwpLyHl7_KqBfdzW6zqHjGzIFvymjFPlvjI,21
- lucid/visual/graph.py,sha256=YjpIDM_lloZARw3sCBiXPl_hT5A2gTk2fEHvwvJWXTk,4599
+ lucid/visual/__init__.py,sha256=NfHhHYNVv9mQQ4MST3-OAIkAcFyYrihJC4qUf88DySI,44
+ lucid/visual/graph.py,sha256=ZSlrJI3dQwYjz8XbgAfNd8-8YuH9Ji7Mz1J6UsnHTaI,4711
+ lucid/visual/mermaid.py,sha256=87hFe4l9EYP6Cg2l2hP2INQiBHKkgVClH5nBWFY9ddY,26499
  lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
  lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
- lucid_dl-2.11.2.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
- lucid_dl-2.11.2.dist-info/METADATA,sha256=udbwTB1UUVhJNWAWTGG_aqKNovOu-Qb_iLTi_HVb13I,12312
- lucid_dl-2.11.2.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
- lucid_dl-2.11.2.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
- lucid_dl-2.11.2.dist-info/RECORD,,
+ lucid_dl-2.11.3.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
+ lucid_dl-2.11.3.dist-info/METADATA,sha256=hffKVg1_fBZA5K3NcRn9KUkIv7B7ilHgr_9OJscZlR0,11898
+ lucid_dl-2.11.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ lucid_dl-2.11.3.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
+ lucid_dl-2.11.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.10.1)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any