lucid-dl 2.11.4__py3-none-any.whl → 2.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/datasets/__init__.py +0 -1
- lucid/datasets/cifar.py +259 -6
- lucid/models/imgclf/vit.py +6 -4
- lucid/nn/__init__.py +1 -1
- lucid/nn/modules/rnn.py +133 -28
- lucid/nn/utils/__init__.py +2 -0
- lucid/nn/{util.py → utils/_grad.py} +21 -2
- lucid/nn/utils/rnn.py +237 -0
- lucid/transforms/image.py +2 -2
- lucid/visual/__init__.py +0 -1
- lucid/visual/mermaid.py +188 -2
- {lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/METADATA +3 -29
- {lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/RECORD +17 -16
- lucid/visual/graph.py +0 -141
- /lucid/models/{util.py → utils.py} +0 -0
- {lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/WHEEL +0 -0
- {lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/top_level.txt +0 -0
lucid/datasets/__init__.py
CHANGED
lucid/datasets/cifar.py
CHANGED
@@ -4,6 +4,8 @@ import openml
 import math

 from typing import SupportsIndex, Tuple, ClassVar
+from pathlib import Path
+import re

 import lucid
 from lucid._tensor import Tensor

@@ -17,6 +19,42 @@ __all__ = ["CIFAR10", "CIFAR100"]
 class CIFAR10(DatasetBase):
     OPENML_ID: ClassVar[int] = 40927

+    def __init__(
+        self,
+        root: str | Path,
+        train: bool | None = True,
+        download: bool | None = False,
+        transform: lucid.nn.Module | None = None,
+        target_transform: lucid.nn.Module | None = None,
+        test_size: float = 0.2,
+        to_tensor: bool = True,
+        *,
+        cache: bool = True,
+        scale: float | None = None,
+        resize: tuple[int, int] | None = None,
+        normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
+        cache_preprocessed: bool = True,
+        preprocess_dtype: lucid.Numeric = lucid.Float16,
+        preprocess_chunk_size: int = 4096,
+    ) -> None:
+        self.cache = cache
+        self.scale = scale
+        self.resize = resize
+        self.normalize = normalize
+        self.cache_preprocessed = cache_preprocessed
+        self.preprocess_dtype = preprocess_dtype
+        self.preprocess_chunk_size = preprocess_chunk_size
+
+        super().__init__(
+            root=root,
+            train=train,
+            download=download,
+            transform=transform,
+            target_transform=target_transform,
+            test_size=test_size,
+            to_tensor=to_tensor,
+        )
+
     def _download(self) -> None:
         try:
             dataset = openml.datasets.get_dataset(self.OPENML_ID)

@@ -26,7 +64,36 @@ class CIFAR10(DatasetBase):
         except Exception as e:
             raise RuntimeError(f"Failed to download the CIFAR-10 dataset. Error: {e}")

-    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+    def _cache_key(self) -> str:
+        parts: list[str] = []
+        if self.scale is not None:
+            parts.append(f"s{self.scale:g}")
+        if self.resize is not None:
+            parts.append(f"r{self.resize[0]}x{self.resize[1]}")
+        if self.normalize is not None:
+            mean, std = self.normalize
+            parts.append("m" + ",".join(f"{v:g}" for v in mean))
+            parts.append("v" + ",".join(f"{v:g}" for v in std))
+        if not parts:
+            return "raw"
+        key = "_".join(parts)
+        return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)
+
+    def _raw_cache_path(self) -> Path:
+        return self.root / "CIFAR10_uint8.npz"
+
+    def _proc_cache_path(self) -> Path:
+        dtype_name = str(self.preprocess_dtype)
+        return self.root / f"CIFAR10_{self._cache_key()}_{dtype_name}.npz"
+
+    def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
+        raw_path = self._raw_cache_path()
+        if self.cache and raw_path.exists():
+            with np.load(raw_path) as npz:
+                images = npz["images"]
+                labels = npz["labels"]
+            return images, labels
+
         csv_path = self.root / "CIFAR10.csv"
         if not csv_path.exists():
             raise RuntimeError(

@@ -36,9 +103,69 @@ class CIFAR10(DatasetBase):

         df = pd.read_csv(csv_path)
         labels = df["class"].values.astype(np.int32)
-        images = df.drop(columns=["class"]).values.astype(np.float32)
+        images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
         images = images.reshape(-1, 3, 32, 32)

+        if self.cache:
+            np.savez_compressed(raw_path, images=images, labels=labels)
+
+        return images, labels
+
+    def _maybe_preprocess_and_cache(
+        self, images_uint8: np.ndarray, labels_int32: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray]:
+        if self.resize is None and self.scale is None and self.normalize is None:
+            return images_uint8.astype(np.float32), labels_int32
+
+        proc_path = self._proc_cache_path()
+        if self.cache and self.cache_preprocessed and proc_path.exists():
+            with np.load(proc_path) as npz:
+                images = npz["images"]
+                labels = npz["labels"]
+            return images, labels
+
+        from lucid.transforms import Compose, Resize, Normalize
+
+        class _Scale(lucid.nn.Module):
+            def __init__(self, factor: float) -> None:
+                super().__init__()
+                self.factor = factor
+
+            def forward(self, x: Tensor) -> Tensor:
+                return x * self.factor
+
+        transforms: list[lucid.nn.Module] = []
+        if self.resize is not None:
+            transforms.append(Resize(self.resize))
+        if self.scale is not None:
+            transforms.append(_Scale(self.scale))
+        if self.normalize is not None:
+            mean, std = self.normalize
+            transforms.append(Normalize(mean=mean, std=std))
+
+        transform = Compose(transforms)
+        n = images_uint8.shape[0]
+        out_h, out_w = self.resize if self.resize is not None else (32, 32)
+
+        out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
+        out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)
+
+        for start in range(0, n, self.preprocess_chunk_size):
+            end = min(start + self.preprocess_chunk_size, n)
+            chunk = images_uint8[start:end].astype(np.float32)
+            x = lucid.to_tensor(chunk, dtype=lucid.Float32)
+            x = transform(x)
+            out_images[start:end] = x.numpy().astype(out_dtype, copy=False)
+
+        if self.cache and self.cache_preprocessed:
+            np.savez_compressed(proc_path, images=out_images, labels=labels_int32)
+
+        return out_images, labels_int32
+
+    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+        images, labels = self._ensure_raw_cache()
+        images, labels = self._maybe_preprocess_and_cache(images, labels)
+
         train_size = int(math.floor(len(images) * (1 - self.test_size)))
         if split == "train":
             images, labels = images[:train_size], labels[:train_size]

@@ -52,7 +179,7 @@ class CIFAR10(DatasetBase):
         return images, labels

     def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
-        image = self.data[index]
+        image = self.data[index]
         label = self.targets[index]

         if self.transform:

@@ -66,6 +193,42 @@ class CIFAR10(DatasetBase):
 class CIFAR100(DatasetBase):
     OPENML_ID: ClassVar[int] = 41983

+    def __init__(
+        self,
+        root: str | Path,
+        train: bool | None = True,
+        download: bool | None = False,
+        transform: lucid.nn.Module | None = None,
+        target_transform: lucid.nn.Module | None = None,
+        test_size: float = 0.2,
+        to_tensor: bool = True,
+        *,
+        cache: bool = True,
+        scale: float | None = None,
+        resize: tuple[int, int] | None = None,
+        normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
+        cache_preprocessed: bool = True,
+        preprocess_dtype: lucid.Numeric = lucid.Float16,
+        preprocess_chunk_size: int = 4096,
+    ) -> None:
+        self.cache = cache
+        self.scale = scale
+        self.resize = resize
+        self.normalize = normalize
+        self.cache_preprocessed = cache_preprocessed
+        self.preprocess_dtype = preprocess_dtype
+        self.preprocess_chunk_size = preprocess_chunk_size
+
+        super().__init__(
+            root=root,
+            train=train,
+            download=download,
+            transform=transform,
+            target_transform=target_transform,
+            test_size=test_size,
+            to_tensor=to_tensor,
+        )
+
     def _download(self) -> None:
         try:
             dataset = openml.datasets.get_dataset(self.OPENML_ID)

@@ -75,7 +238,37 @@ class CIFAR100(DatasetBase):
         except Exception as e:
             raise RuntimeError(f"Failed to download the CIFAR-100 dataset. Error: {e}")

-    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+    def _cache_key(self) -> str:
+        parts: list[str] = []
+        if self.scale is not None:
+            parts.append(f"s{self.scale:g}")
+        if self.resize is not None:
+            parts.append(f"r{self.resize[0]}x{self.resize[1]}")
+        if self.normalize is not None:
+            mean, std = self.normalize
+            parts.append("m" + ",".join(f"{v:g}" for v in mean))
+            parts.append("v" + ",".join(f"{v:g}" for v in std))
+        if not parts:
+            return "raw"
+
+        key = "_".join(parts)
+        return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)
+
+    def _raw_cache_path(self) -> Path:
+        return self.root / "CIFAR100_uint8.npz"
+
+    def _proc_cache_path(self) -> Path:
+        dtype_name = str(self.preprocess_dtype)
+        return self.root / f"CIFAR100_{self._cache_key()}_{dtype_name}.npz"
+
+    def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
+        raw_path = self._raw_cache_path()
+        if self.cache and raw_path.exists():
+            with np.load(raw_path) as npz:
+                images = npz["images"]
+                labels = npz["labels"]
+            return images, labels
+
         csv_path = self.root / "CIFAR100.csv"
         if not csv_path.exists():
             raise RuntimeError(

@@ -85,9 +278,69 @@ class CIFAR100(DatasetBase):

         df = pd.read_csv(csv_path)
         labels = df["class"].values.astype(np.int32)
-        images = df.drop(columns=["class"]).values.astype(np.float32)
+        images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
         images = images.reshape(-1, 3, 32, 32)

+        if self.cache:
+            np.savez_compressed(raw_path, images=images, labels=labels)
+
+        return images, labels
+
+    def _maybe_preprocess_and_cache(
+        self, images_uint8: np.ndarray, labels_int32: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray]:
+        if self.resize is None and self.scale is None and self.normalize is None:
+            return images_uint8.astype(np.float32), labels_int32
+
+        proc_path = self._proc_cache_path()
+        if self.cache and self.cache_preprocessed and proc_path.exists():
+            with np.load(proc_path) as npz:
+                images = npz["images"]
+                labels = npz["labels"]
+            return images, labels
+
+        from lucid.transforms import Compose, Resize, Normalize
+
+        class _Scale(lucid.nn.Module):
+            def __init__(self, factor: float) -> None:
+                super().__init__()
+                self.factor = factor
+
+            def forward(self, x: Tensor) -> Tensor:
+                return x * self.factor
+
+        transforms: list[lucid.nn.Module] = []
+        if self.resize is not None:
+            transforms.append(Resize(self.resize))
+        if self.scale is not None:
+            transforms.append(_Scale(self.scale))
+        if self.normalize is not None:
+            mean, std = self.normalize
+            transforms.append(Normalize(mean=mean, std=std))
+
+        transform = Compose(transforms)
+        n = images_uint8.shape[0]
+        out_h, out_w = self.resize if self.resize is not None else (32, 32)
+
+        out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
+        out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)
+
+        for start in range(0, n, self.preprocess_chunk_size):
+            end = min(start + self.preprocess_chunk_size, n)
+            chunk = images_uint8[start:end].astype(np.float32)
+            x = lucid.to_tensor(chunk, dtype=lucid.Float32)
+            x = transform(x)
+            out_images[start:end] = x.numpy().astype(out_dtype, copy=False)
+
+        if self.cache and self.cache_preprocessed:
+            np.savez_compressed(proc_path, images=out_images, labels=labels_int32)
+
+        return out_images, labels_int32
+
+    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+        images, labels = self._ensure_raw_cache()
+        images, labels = self._maybe_preprocess_and_cache(images, labels)
+
         train_size = int(math.floor(len(images) * (1 - self.test_size)))
         if split == "train":
             images, labels = images[:train_size], labels[:train_size]

@@ -101,7 +354,7 @@ class CIFAR100(DatasetBase):
         return images, labels

     def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
-        image = self.data[index]
+        image = self.data[index]
         label = self.targets[index]

         if self.transform:
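Editor's note: the new constructor keywords add a two-level on-disk cache — raw uint8 pixels go into `CIFAR10_uint8.npz` once, and each scale/resize/normalize combination is preprocessed in chunks and cached under a key derived from those settings. A minimal usage sketch; the `./data` root and the normalization statistics are illustrative assumptions, not values from the diff:

```python
import lucid
from lucid.datasets import CIFAR10

# First construction downloads, preprocesses in `preprocess_chunk_size` chunks,
# and writes a cache file named from the settings (a key like "s..._m...,..._v...");
# later constructions with the same settings load the cached .npz directly.
train_set = CIFAR10(
    root="./data",                    # illustrative local path
    train=True,
    download=True,
    scale=1.0 / 255.0,                # rescale uint8 pixels into [0, 1]
    normalize=((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    preprocess_dtype=lucid.Float16,   # cached arrays are stored at this dtype
)
image, label = train_set[0]
```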
lucid/models/imgclf/vit.py
CHANGED
@@ -32,10 +32,12 @@ class ViT(nn.Module):
             in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
         )

-        self.cls_token = nn.Parameter(lucid.
-        self.pos_emb = nn.Parameter(
-
-        )
+        self.cls_token = nn.Parameter(lucid.zeros(1, 1, embedding_dim))
+        self.pos_emb = nn.Parameter(lucid.zeros(1, 1 + self.num_patches, embedding_dim))
+
+        nn.init.normal(self.cls_token, std=0.02)
+        nn.init.normal(self.pos_emb, std=0.02)
+
         self.dropout = nn.Dropout(dropout_rate)

         encoder_layer = nn.TransformerEncoderLayer(
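Editor's note: the rewritten ViT initialization allocates zero tensors and then re-draws them from N(0, 0.02²), the usual ViT recipe for the class token and positional embedding. A standalone sketch of the same pattern; the toy sizes are assumptions, and `nn.init.normal` is taken to mutate the parameter in place, as the diff implies:

```python
import lucid
import lucid.nn as nn

embedding_dim, num_patches = 192, 64   # assumed toy sizes

cls_token = nn.Parameter(lucid.zeros(1, 1, embedding_dim))
pos_emb = nn.Parameter(lucid.zeros(1, 1 + num_patches, embedding_dim))

# Re-initialize both from N(0, 0.02^2), matching the ViT change above.
nn.init.normal(cls_token, std=0.02)
nn.init.normal(pos_emb, std=0.02)
```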
lucid/nn/__init__.py
CHANGED
lucid/nn/modules/rnn.py
CHANGED
@@ -5,6 +5,7 @@ import lucid.nn as nn
 import lucid.nn.functional as F

 from lucid._tensor import Tensor
+from lucid.nn.utils.rnn import PackedSequence
 from lucid.types import Numeric, _DeviceType

 from .activation import Tanh, ReLU

@@ -351,21 +352,47 @@
     )

     def forward(
-        self,
-
-
-
-
-
+        self,
+        input_: Tensor | PackedSequence,
+        hx: Tensor | tuple[Tensor, Tensor] | None = None,
+    ) -> (
+        tuple[Tensor | PackedSequence, Tensor]
+        | tuple[Tensor | PackedSequence, tuple[Tensor, Tensor]]
+    ):
+        is_packed = isinstance(input_, PackedSequence)
+        if is_packed:
+            data = input_.data
+            batch_sizes = input_.batch_sizes
+            if data.ndim != 2:
+                raise ValueError(
+                    "RNNBase expected packed data with 2 dimensions, "
+                    f"got {data.ndim} dimensions"
+                )
+            if batch_sizes.ndim != 1 or batch_sizes.shape[0] == 0:
+                raise ValueError(
+                    "PackedSequence batch_sizes must be a non-empty 1D tensor"
+                )

-
-
+            batch_size = int(batch_sizes[0].item())
+            feat = data.shape[1]
+            if feat != self.input_size:
+                raise ValueError(
+                    f"RNNBase expected input with feature size {self.input_size}, got {feat}"
+                )
+        else:
+            if input_.ndim != 3:
+                raise ValueError(
+                    f"RNNBase expected input with 3 dimensions, got {input_.ndim} dimensions"
+                )

-
-
-
-
-
+            if self.batch_first:
+                input_ = input_.swapaxes(0, 1)
+
+            seq_len, batch_size, feat = input_.shape
+            if feat != self.input_size:
+                raise ValueError(
+                    f"RNNBase expected input with feature size {self.input_size}, got {feat}"
+                )

         if self.is_lstm:
             if hx is None:

@@ -410,7 +437,7 @@
         if hx.shape[2] != self.hidden_size:
             raise ValueError("Incorrect hidden size in hx")

-        layer_input = input_
+        layer_input = data if is_packed else input_
         h_n_list: list[Tensor] = []
         c_n_list: list[Tensor] | None = [] if self.is_lstm else None

@@ -420,33 +447,111 @@
                 c_t = hx_c[layer_idx]
             else:
                 h_t = hx[layer_idx]
+
             outputs = []
+            if is_packed:
+                final_h: list[Tensor] = []
+                final_c: list[Tensor] | None = [] if self.is_lstm else None
+                offset = 0
+
+                prev_bs: int | None = None
+                max_len = int(batch_sizes.shape[0])
+                for t in range(max_len):
+                    bs = int(batch_sizes[t].item())
+                    if bs == 0:
+                        break
+
+                    if prev_bs is None:
+                        prev_bs = bs
+                    if bs > prev_bs:
+                        raise ValueError(
+                            "PackedSequence batch_sizes must be non-increasing"
+                        )
+
+                    if bs < prev_bs:
+                        final_h.append(h_t[bs:prev_bs])
+                        if self.is_lstm and final_c is not None:
+                            final_c.append(c_t[bs:prev_bs])
+
+                    h_t = h_t[:bs]
+                    if self.is_lstm:
+                        c_t = c_t[:bs]
+
+                    step_input = layer_input[offset : offset + bs]
+                    offset += bs
+
+                    if self.is_lstm:
+                        h_t, c_t = cell(step_input, (h_t, c_t))
+                    else:
+                        h_t = cell(step_input, h_t)
+
+                    outputs.append(h_t)
+                    prev_bs = bs
+
+                final_h.append(h_t)
+                if self.is_lstm and final_c is not None:
+                    final_c.append(c_t)
+
+                h_n_list.append(
+                    lucid.concatenate(tuple(reversed(final_h)), axis=0).unsqueeze(
+                        axis=0
+                    )
+                )
+                if self.is_lstm and final_c is not None and c_n_list is not None:
+                    c_n_list.append(
+                        lucid.concatenate(tuple(reversed(final_c)), axis=0).unsqueeze(
+                            axis=0
+                        )
+                    )
+
+                layer_output = (
+                    lucid.concatenate(tuple(outputs), axis=0)
+                    if outputs
+                    else layer_input[:0]
+                )

-
-
-
-
-
-
-
+            else:
+                for t in range(seq_len):
+                    if self.is_lstm:
+                        h_t, c_t = cell(layer_input[t], (h_t, c_t))
+                        outputs.append(h_t.unsqueeze(axis=0))
+                    else:
+                        h_t = cell(layer_input[t], h_t)
+                        outputs.append(h_t.unsqueeze(axis=0))

-
+                layer_output = lucid.concatenate(tuple(outputs), axis=0)

             if self.training and self.dropout > 0.0 and layer_idx < self.num_layers - 1:
                 layer_output = F.dropout(layer_output, p=self.dropout)

-
-
-            c_n_list
+            if not is_packed:
+                h_n_list.append(h_t.unsqueeze(axis=0))
+                if self.is_lstm and c_n_list is not None:
+                    c_n_list.append(c_t.unsqueeze(axis=0))
             layer_input = layer_output

-
+        if is_packed:
+            output = PackedSequence(
+                data=layer_input,
+                batch_sizes=batch_sizes,
+                sorted_indices=input_.sorted_indices,
+                unsorted_indices=input_.unsorted_indices,
+            )
+        else:
+            output = layer_input
+
         h_n = lucid.concatenate(tuple(h_n_list), axis=0)
         if self.is_lstm and c_n_list is not None:
             c_n = lucid.concatenate(tuple(c_n_list), axis=0)

-        if
-
+        if is_packed:
+            if input_.unsorted_indices is not None:
+                h_n = h_n[:, input_.unsorted_indices]
+                if self.is_lstm and c_n_list is not None:
+                    c_n = c_n[:, input_.unsorted_indices]
+        else:
+            if self.batch_first:
+                output = output.swapaxes(0, 1)

         if self.is_lstm and c_n_list is not None:
             return output, (h_n, c_n)
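Editor's note: the packed branch above consumes `batch_sizes[t]` rows of the packed data per time step and, whenever the batch shrinks, peels the trailing rows of the hidden state off as the final states of the sequences that just ended (collected in reverse and re-reversed at the end). A plain-Python walk-through of that bookkeeping, with no lucid dependency:

```python
# Three sequences of lengths 3, 2, 1 pack into batch_sizes = [3, 2, 1]:
# step 0 consumes packed rows 0..2, step 1 rows 3..4, step 2 row 5.
batch_sizes = [3, 2, 1]
offset, prev_bs = 0, batch_sizes[0]
for t, bs in enumerate(batch_sizes):
    rows = list(range(offset, offset + bs))
    finished = prev_bs - bs            # sequences whose final state is peeled off
    print(f"t={t}: rows {rows}, {finished} sequence(s) just finished")
    offset, prev_bs = offset + bs, bs
```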
lucid/nn/{util.py → utils/_grad.py}
RENAMED

@@ -6,7 +6,7 @@ from lucid._tensor import Tensor
 from lucid.types import _Scalar


-__all__ = ["grad_norm", "clip_grad_norm", "clip_grad_value"]
+__all__ = ["grad_norm", "get_total_norm", "clip_grad_norm", "clip_grad_value"]


 def _as_iter(parameters: Iterable[Tensor] | Tensor) -> list[Tensor]:

@@ -32,6 +32,25 @@ def grad_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tens
     return Tensor(total_norm, device=device)


+def get_total_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tensor:
+    parameters = _as_iter(parameters)
+    if not parameters:
+        return Tensor(0.0)
+
+    device = parameters[0].device
+    grads: list[Tensor] = [p.grad for p in parameters if p.grad is not None]
+    if not grads:
+        return Tensor(0.0, device=device)
+
+    norm_pow_sum = 0.0
+    for g in grads:
+        grad_norm = lucid.linalg.norm(lucid.ravel(g), ord=norm_type).item()
+        norm_pow_sum += grad_norm**norm_type
+
+    total_norm = norm_pow_sum ** (1.0 / norm_type)
+    return Tensor(total_norm, device=device)
+
+
 def clip_grad_norm(
     parameters: Iterable[Tensor] | Tensor,
     max_norm: _Scalar,

@@ -39,7 +58,7 @@ def clip_grad_norm(
     eps: float = 1e-7,
 ) -> float:
     params: list[Tensor] = [p for p in _as_iter(parameters) if p.grad is not None]
-    total_norm = grad_norm(params, norm_type=norm_type)
+    total_norm = get_total_norm(params, norm_type=norm_type)

     clip_coef = float(max_norm) / (total_norm.item() + eps)
     if clip_coef < 1.0:
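Editor's note: `get_total_norm` computes the global gradient norm (the p-norm of the per-tensor gradient norms), and `clip_grad_norm` now delegates to it. A small sketch of the intended use; it assumes `lucid.nn.utils` re-exports the `_grad.py` helpers and that `lucid.ones` follows the NumPy-style constructors (`lucid.zeros`, `lucid.full`) seen elsewhere in this diff:

```python
import lucid
from lucid.nn.utils import get_total_norm, clip_grad_norm  # assumed re-export

# One parameter with a known gradient so the norms are easy to verify.
p = lucid.ones(4, requires_grad=True)
loss = (p * 3.0).sum()
loss.backward()                        # dL/dp = 3 for every entry

print(get_total_norm([p]).item())      # L2 norm: sqrt(4 * 3**2) = 6.0
clip_grad_norm([p], max_norm=1.0)      # scales gradients down by ~1/6
print(get_total_norm([p]).item())      # ~1.0 after clipping
```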
lucid/nn/utils/rnn.py
ADDED
@@ -0,0 +1,237 @@
+from dataclasses import dataclass
+from typing import Iterable, Sequence
+
+import lucid
+
+from lucid._tensor import Tensor
+from lucid.types import _Scalar
+
+
+__all__ = [
+    "PackedSequence",
+    "pad_sequence",
+    "pack_padded_sequence",
+    "pad_packed_sequence",
+    "pack_sequence",
+    "unpack_sequence",
+]
+
+
+@dataclass(frozen=True)
+class PackedSequence:
+    data: Tensor
+    batch_sizes: Tensor
+    sorted_indices: Tensor | None = None
+    unsorted_indices: Tensor | None = None
+
+
+def pad_sequence(
+    sequences: Iterable[Tensor], batch_first: bool = False, padding_value: _Scalar = 0
+) -> Tensor:
+    seq_list = list(sequences)
+    if not seq_list:
+        raise ValueError("pad_sequence expected a non-empty iterable of Tensors")
+
+    first = seq_list[0]
+    if not isinstance(first, Tensor):
+        raise TypeError("pad_sequence expects Tensor elements")
+
+    ndim = first.ndim
+    if ndim < 1:
+        raise ValueError("pad_sequence expects tensors with at least 1 dimension")
+
+    trailing_shape = first.shape[1:]
+    device = first.device
+    dtype = first.dtype
+
+    lengths: list[int] = []
+    for idx, seq in enumerate(seq_list):
+        if not isinstance(seq, Tensor):
+            raise TypeError("pad_sequence expects Tensor elements")
+        if seq.ndim != ndim:
+            raise ValueError(
+                f"pad_sequence expects tensors with {ndim} dimensions, "
+                f"got {seq.ndim} at index {idx}"
+            )
+        if seq.shape[1:] != trailing_shape:
+            raise ValueError(
+                "pad_sequence expects all tensors to share the same trailing shape"
+            )
+        if seq.device != device:
+            raise ValueError("pad_sequence expects all tensors on the same device")
+        if seq.dtype != dtype:
+            raise ValueError("pad_sequence expects all tensors with the same dtype")
+        lengths.append(seq.shape[0])
+
+    max_len = max(lengths)
+    batch_size = len(seq_list)
+
+    if batch_first:
+        out_shape = (batch_size, max_len, *trailing_shape)
+    else:
+        out_shape = (max_len, batch_size, *trailing_shape)
+
+    output = lucid.full(out_shape, padding_value, dtype=dtype, device=device)
+    for i, seq in enumerate(seq_list):
+        length = lengths[i]
+        if length == 0:
+            continue
+        if batch_first:
+            output[i, :length] = seq
+        else:
+            output[:length, i] = seq
+
+    return output
+
+
+def _as_lengths(lengths: Sequence[int] | Tensor, *, device: str) -> Tensor:
+    if isinstance(lengths, Tensor):
+        return lengths
+    return Tensor(list(lengths), device=device)
+
+
+def _invert_permutation(indices: Tensor) -> Tensor:
+    return lucid.argsort(indices, axis=0)
+
+
+def pack_padded_sequence(
+    input_: Tensor,
+    lengths: Sequence[int] | Tensor,
+    batch_first: bool = False,
+    enforce_sorted: bool = True,
+) -> PackedSequence:
+    if input_.ndim < 2:
+        raise ValueError(
+            f"pack_padded_sequence expected input with at least 2 dims, got {input_.ndim}"
+        )
+
+    if batch_first:
+        input_ = input_.swapaxes(0, 1)
+
+    seq_len, batch_size = input_.shape[0], input_.shape[1]
+    lengths_t = _as_lengths(lengths, device=input_.device)
+    if lengths_t.ndim != 1:
+        raise ValueError("lengths must be a 1D sequence or tensor")
+    if lengths_t.shape[0] != batch_size:
+        raise ValueError(
+            f"lengths size {lengths_t.shape[0]} does not match batch size {batch_size}"
+        )
+
+    sorted_indices = None
+    unsorted_indices = None
+
+    if enforce_sorted:
+        sorted_lengths = lengths_t
+    else:
+        sorted_indices = lucid.argsort(lengths_t, descending=True, axis=0)
+        unsorted_indices = _invert_permutation(sorted_indices)
+
+        lengths_t = lengths_t[sorted_indices]
+        input_ = input_[:, sorted_indices]
+        sorted_lengths = lengths_t
+
+    max_len = int(sorted_lengths[0].item())
+    if max_len > seq_len:
+        raise ValueError(
+            f"lengths has max {max_len} but input has sequence length {seq_len}"
+        )
+
+    batch_sizes: list[int] = []
+    chunks: list[Tensor] = []
+    for t in range(max_len):
+        bs = int((sorted_lengths > t).sum().item())
+        batch_sizes.append(bs)
+        if bs == 0:
+            break
+        chunks.append(input_[t, :bs])
+
+    if not chunks:
+        data = input_[:0]
+    else:
+        data = lucid.concatenate(tuple(chunks), axis=0)
+
+    return PackedSequence(
+        data=data,
+        batch_sizes=Tensor(batch_sizes, device=input_.device),
+        sorted_indices=sorted_indices,
+        unsorted_indices=unsorted_indices,
+    )
+
+
+def pad_packed_sequence(
+    sequence: PackedSequence, batch_first: bool = False, padding_value: _Scalar = 0
+) -> tuple[Tensor, Tensor]:
+    data = sequence.data
+    batch_sizes = sequence.batch_sizes
+    if batch_sizes.ndim != 1:
+        raise ValueError("batch_sizes must be 1D")
+
+    max_len = int(batch_sizes.shape[0])
+    if max_len == 0:
+        raise ValueError("batch_sizes must be non-empty")
+
+    batch_size = int(batch_sizes[0].item())
+    trailing_shape = data.shape[1:]
+
+    if batch_first:
+        out_shape = (batch_size, max_len, *trailing_shape)
+    else:
+        out_shape = (max_len, batch_size, *trailing_shape)
+
+    output = lucid.full(out_shape, padding_value, dtype=data.dtype, device=data.device)
+
+    lengths = [0] * batch_size
+    offset = 0
+    for t in range(max_len):
+        bs = int(batch_sizes[t].item())
+        if bs == 0:
+            break
+
+        chunk = data[offset : offset + bs]
+        offset += bs
+        for i in range(bs):
+            lengths[i] += 1
+        if batch_first:
+            output[:bs, t] = chunk
+        else:
+            output[t, :bs] = chunk
+
+    lengths_t = Tensor(lengths, device=data.device)
+    if sequence.unsorted_indices is not None:
+        if batch_first:
+            output = output[sequence.unsorted_indices]
+        else:
+            output = output[:, sequence.unsorted_indices]
+        lengths_t = lengths_t[sequence.unsorted_indices]
+
+    return output, lengths_t
+
+
+def pack_sequence(
+    sequences: Iterable[Tensor], enforce_sorted: bool = True
+) -> PackedSequence:
+    seq_list = list(sequences)
+    if not seq_list:
+        raise ValueError("pack_sequence expected a non-empty iterable of Tensors")
+
+    lengths = [seq.shape[0] for seq in seq_list]
+    padded = pad_sequence(seq_list, batch_first=False, padding_value=0.0)
+    return pack_padded_sequence(
+        padded, lengths, batch_first=False, enforce_sorted=enforce_sorted
+    )
+
+
+def unpack_sequence(
+    sequence: PackedSequence, batch_first: bool = False
+) -> list[Tensor]:
+    padded, lengths = pad_packed_sequence(
+        sequence, batch_first=batch_first, padding_value=0.0
+    )
+    result: list[Tensor] = []
+    for i, length in enumerate(lengths):
+        l = int(length.item())
+        if batch_first:
+            result.append(padded[i, :l])
+        else:
+            result.append(padded[:l, i])
+    return result
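Editor's note: a round trip through the new helpers, mirroring the PyTorch-style semantics this module implements. Tensor construction via `lucid.ones` is an assumption; everything else uses signatures visible in the file above:

```python
import lucid
from lucid.nn.utils.rnn import (
    pad_sequence,
    pack_padded_sequence,
    pad_packed_sequence,
    unpack_sequence,
)

# Three time-major sequences of lengths 3, 2, 1 with feature size 2.
seqs = [lucid.ones(3, 2), lucid.ones(2, 2), lucid.ones(1, 2)]

padded = pad_sequence(seqs)                      # shape (3, 3, 2)
packed = pack_padded_sequence(padded, lengths=[3, 2, 1])
print(packed.batch_sizes)                        # [3, 2, 1] -> 6 packed rows

restored, lengths = pad_packed_sequence(packed)  # padded tensor + lengths
originals = unpack_sequence(packed)              # shapes (3, 2), (2, 2), (1, 2)
```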
lucid/transforms/image.py
CHANGED
@@ -35,8 +35,8 @@ def add_batch_dim(func: Callable[..., Tensor]) -> Callable:
 class Normalize(nn.Module):
     def __init__(self, mean: tuple[float, ...], std: tuple[float, ...]) -> None:
         super().__init__()
-        self.mean = lucid.tensor(mean)
-        self.std = lucid.tensor(std)
+        self.mean = lucid.tensor(mean).reshape(1, len(mean), 1, 1)
+        self.std = lucid.tensor(std).reshape(1, len(std), 1, 1)

     @add_batch_dim
     def forward(self, img: Tensor) -> Tensor:
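Editor's note: the `Normalize` fix reshapes the per-channel statistics to `(1, C, 1, 1)` so they broadcast over an NCHW batch instead of lining up with the trailing width axis. The shape algebra, sketched with NumPy for clarity:

```python
import numpy as np

img = np.random.rand(8, 3, 32, 32)                  # (N, C, H, W) batch
mean = np.array([0.5, 0.5, 0.5]).reshape(1, 3, 1, 1)
std = np.array([0.25, 0.25, 0.25]).reshape(1, 3, 1, 1)

# (1, 3, 1, 1) broadcasts against (8, 3, 32, 32) per channel; a flat (3,)
# vector would instead align with the W axis and fail whenever W != 3.
out = (img - mean) / std
print(out.shape)                                    # (8, 3, 32, 32)
```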
lucid/visual/__init__.py
CHANGED
lucid/visual/mermaid.py
CHANGED
@@ -9,7 +9,7 @@ from lucid._tensor import Tensor
 from lucid.types import _ShapeLike


-__all__ = ["
+__all__ = ["build_tensor_mermaid_chart", "build_module_mermaid_chart"]


 _NN_MODULES_PREFIX = "lucid.nn.modules."

@@ -255,7 +255,7 @@ def _collapse_repeated_children(
     return out


-def build_mermaid_chart(
+def build_module_mermaid_chart(
     module: nn.Module,
     input_shape: _ShapeLike | list[_ShapeLike] | None = None,
     inputs: Iterable[Tensor] | Tensor | None = None,

@@ -751,6 +751,192 @@ def build_mermaid_chart(
     return text


+def build_mermaid_chart(
+    module: nn.Module,
+    input_shape: _ShapeLike | list[_ShapeLike] | None = None,
+    inputs: Iterable[Tensor] | Tensor | None = None,
+    depth: int = 2,
+    direction: str = "LR",
+    include_io: bool = True,
+    show_params: bool = False,
+    return_lines: bool = False,
+    copy_to_clipboard: bool = False,
+    compact: bool = False,
+    use_class_defs: bool = False,
+    end_semicolons: bool = True,
+    edge_mode: Literal["dataflow", "execution"] = "execution",
+    collapse_repeats: bool = True,
+    repeat_min: int = 2,
+    color_by_subpackage: bool = True,
+    container_name_from_attr: bool = True,
+    edge_stroke_width: float = 2.0,
+    emphasize_model_title: bool = True,
+    model_title_font_px: int = 20,
+    show_shapes: bool = False,
+    hide_subpackages: Iterable[str] = (),
+    hide_module_names: Iterable[str] = (),
+    dash_multi_input_edges: bool = True,
+    subgraph_fill: str = "#000000",
+    subgraph_fill_opacity: float = 0.05,
+    subgraph_stroke: str = "#000000",
+    subgraph_stroke_opacity: float = 0.75,
+    force_text_color: str | None = None,
+    edge_curve: str = "natural",
+    node_spacing: int = 50,
+    rank_spacing: int = 50,
+    **forward_kwargs,
+) -> str | list[str]:
+    return build_module_mermaid_chart(
+        module,
+        input_shape=input_shape,
+        inputs=inputs,
+        depth=depth,
+        direction=direction,
+        include_io=include_io,
+        show_params=show_params,
+        return_lines=return_lines,
+        copy_to_clipboard=copy_to_clipboard,
+        compact=compact,
+        use_class_defs=use_class_defs,
+        end_semicolons=end_semicolons,
+        edge_mode=edge_mode,
+        collapse_repeats=collapse_repeats,
+        repeat_min=repeat_min,
+        color_by_subpackage=color_by_subpackage,
+        container_name_from_attr=container_name_from_attr,
+        edge_stroke_width=edge_stroke_width,
+        emphasize_model_title=emphasize_model_title,
+        model_title_font_px=model_title_font_px,
+        show_shapes=show_shapes,
+        hide_subpackages=hide_subpackages,
+        hide_module_names=hide_module_names,
+        dash_multi_input_edges=dash_multi_input_edges,
+        subgraph_fill=subgraph_fill,
+        subgraph_fill_opacity=subgraph_fill_opacity,
+        subgraph_stroke=subgraph_stroke,
+        subgraph_stroke_opacity=subgraph_stroke_opacity,
+        force_text_color=force_text_color,
+        edge_curve=edge_curve,
+        node_spacing=node_spacing,
+        rank_spacing=rank_spacing,
+        **forward_kwargs,
+    )
+
+
+def build_tensor_mermaid_chart(
+    tensor: Tensor,
+    horizontal: bool = False,
+    title: str | None = None,
+    start_id: int | None = None,
+    end_semicolons: bool = True,
+    copy_to_clipboard: bool = False,
+    use_class_defs: bool = True,
+    op_fill: str = "lightgreen",
+    param_fill: str = "plum",
+    result_fill: str = "lightcoral",
+    leaf_fill: str = "lightgray",
+    grad_fill: str = "lightblue",
+    start_fill: str = "gold",
+    stroke_color: str = "#666",
+    stroke_width_px: int = 1,
+) -> str:
+    direction = "LR" if horizontal else "TD"
+    lines: list[str] = [f"flowchart {direction}"]
+    if title:
+        lines.append(f"%% {title}")
+
+    result_id: int = id(tensor)
+    visited: set[int] = set()
+    nodes_to_draw: list[Tensor] = []
+
+    def dfs(t: Tensor) -> None:
+        if id(t) in visited:
+            return
+        visited.add(id(t))
+        for p in t._prev:
+            dfs(p)
+        nodes_to_draw.append(t)
+
+    def tensor_node_id(t: Tensor) -> str:
+        return f"t_{id(t)}"
+
+    def op_node_id(op: object) -> str:
+        return f"op_{id(op)}"
+
+    def add_node(node_id: str, label: str, kind: str) -> None:
+        if node_id in defined_nodes:
+            return
+        defined_nodes.add(node_id)
+        if kind == "op":
+            lines.append(f'{node_id}(("{label}"))')
+        else:
+            lines.append(f'{node_id}["{label}"]')
+
+    dfs(tensor)
+
+    defined_nodes: set[str] = set()
+    edge_lines: list[str] = []
+    class_lines: list[str] = []
+
+    for t in nodes_to_draw:
+        t_id = tensor_node_id(t)
+
+        if not t.is_leaf and t._op is not None:
+            op_id = op_node_id(t._op)
+            op_label = type(t._op).__name__
+            add_node(op_id, op_label, "op")
+            edge_lines.append(f"{op_id} --> {t_id}")
+            class_lines.append(f"class {op_id} op")
+            for inp in t._prev:
+                edge_lines.append(f"{tensor_node_id(inp)} --> {op_id}")

+        shape_label = str(t.shape) if t.ndim > 0 else str(t.item())
+        add_node(t_id, shape_label, "tensor")
+
+        if start_id is not None and id(t) == start_id:
+            class_lines.append(f"class {t_id} start")
+        elif isinstance(t, nn.Parameter):
+            class_lines.append(f"class {t_id} param")
+        elif id(t) == result_id:
+            class_lines.append(f"class {t_id} result")
+        elif not t.requires_grad:
+            class_lines.append(f"class {t_id} leaf")
+        else:
+            class_lines.append(f"class {t_id} grad")
+
+    lines.extend(edge_lines)
+    if use_class_defs:
+        lines.append(
+            f"classDef op fill:{op_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.append(
+            f"classDef param fill:{param_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.append(
+            f"classDef result fill:{result_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.append(
+            f"classDef leaf fill:{leaf_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.append(
+            f"classDef grad fill:{grad_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.append(
+            f"classDef start fill:{start_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+        )
+        lines.extend(class_lines)
+
+    if end_semicolons:
+        lines = [
+            f"{line};" if line and not line.endswith(";") else line for line in lines
+        ]
+
+    text = "\n".join(lines)
+    if copy_to_clipboard:
+        _copy_to_clipboard(text)
+    return text
+
+
 def _copy_to_clipboard(text: str) -> None:
     import os
     import shutil
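Editor's note: `build_tensor_mermaid_chart` walks a tensor's autograd graph depth-first and emits Mermaid flowchart text, with class styles distinguishing ops, parameters, leaves, and the result. A usage sketch; the tensor construction (`lucid.ones`) and the `lucid.visual` re-export are assumptions:

```python
import lucid
from lucid.visual import build_tensor_mermaid_chart  # assumed re-export

x = lucid.ones(2, 2, requires_grad=True)   # tensor construction assumed
y = (x * 3.0).sum()

chart = build_tensor_mermaid_chart(y, horizontal=True, title="toy graph")
print(chart)  # paste the text into any Mermaid renderer
```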
{lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lucid-dl
-Version: 2.11.4
+Version: 2.12.0
 Summary: Lumerico's Comprehensive Interface for Deep Learning
 Home-page: https://github.com/ChanLumerico/lucid
 Author: ChanLumerico

@@ -48,35 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti
 ### 🔥 What's New

-- Added
+- Added new visual tool: `lucid.visual.build_tensor_mermaid_chart`, which builds a Mermaid chart of a given tensor's computational graph

-```python
-def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
-```
-```python
-def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
-```
-```python
-def register_backward_hook(self, hook: Callable)
-```
-```python
-def register_full_backward_pre_hook(self, hook: Callable)
-```
-```python
-def register_full_backward_hook(self, hook: Callable)
-```
-```python
-def register_state_dict_pre_hook(self, hook: Callable)
-```
-```python
-def register_state_dict_hook(self, hook: Callable)
-```
-```python
-def register_load_state_dict_pre_hook(self, hook: Callable)
-```
-```python
-def register_load_state_dict_post_hook(self, hook: Callable)
-```
+- Added additional `nn.Module` hooks for richer introspection during training

 ## 🔧 How to Install

{lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/RECORD
CHANGED

@@ -21,16 +21,16 @@ lucid/autograd/__init__.py,sha256=hDoK_B2chRFVhoxsT4vxRKangzBEMWqF8gj2hdoTenk,67
 lucid/data/__init__.py,sha256=qrDIQsnix5ZUEa0yrtomaaWbNJyJ3xEr2gdhRvg70_8,118
 lucid/data/_base.py,sha256=RM8xpBl8qFhm19n7eER_jOsRaxkL3rbOkwUvn6VetSE,5921
 lucid/data/_util.py,sha256=UsbliOrGmM0f1vqppoBPn3RSx53PIqcVx_yVOlHZB6A,1985
-lucid/datasets/__init__.py,sha256=
+lucid/datasets/__init__.py,sha256=SY0bCxIUGrSBNfEs4KmOawirmkJ8i-Cyz2CZMBkMsmk,42
 lucid/datasets/_base.py,sha256=yeXPm3La3Beel7U_yPrxjXgGtjndJ3T6NYaQ2_H_Fak,1325
-lucid/datasets/cifar.py,sha256=
+lucid/datasets/cifar.py,sha256=r8KX-j6svhx3Kk1hSNhexUsJws7Bj31PlLMVRS68dC4,13080
 lucid/datasets/mnist.py,sha256=PUXW2UwmlXJFVJiNkI9Jm58Qe4qWHGA63znkk-y9INM,8603
 lucid/einops/__init__.py,sha256=9Dlmfw6PsIU9b_a89Zre4yV2rztRHPCL4QpsUnXJwjM,802
 lucid/einops/_func.py,sha256=XXsX9lse_0turKoFnOTtLdY6hBUi0gq_8K81G7nr80I,21026
 lucid/linalg/__init__.py,sha256=N-LrlC3qSsOMt6Ad1-PP3Qc3QH6EWNf5P50GBvwb9aQ,1118
 lucid/linalg/_func.py,sha256=Iyeut5nHwQmO8N326kQUaTjgoKVoBaxt_gy_3NXXD60,16378
 lucid/models/__init__.py,sha256=wegfOBvwJTFFee8eVt90zJoLsbbEpdT5G2y-mpO5xcE,89
-lucid/models/
+lucid/models/utils.py,sha256=2g8FLcMLRgVxgGEaYuwJyFxeXu-A_a4_MVr0K-TNh74,5195
 lucid/models/imgclf/__init__.py,sha256=kQH-nNu8_TPJ7Av151WSpcY4GJ06gGAd6Ozs3m3KMcE,590
 lucid/models/imgclf/alex.py,sha256=fZsPdCjWUseCrxBwKj-i5fPSDYLgBpfm0SJe07YKRuE,1472
 lucid/models/imgclf/coatnet.py,sha256=HKjpy-lBKgz743EijT7jEeMxYjrZHzgU5fOrgtZfxYg,13720

@@ -55,7 +55,7 @@ lucid/models/imgclf/senet.py,sha256=I5o9eHWzquNyLqZM4thMtZtIBDYGczjARl1Isx6GyCk,
 lucid/models/imgclf/sknet.py,sha256=rENInsSB2yLXJ7A9kWZ-9lDFXcKaUOIpzV0359umPRI,4535
 lucid/models/imgclf/swin.py,sha256=lClJTX6ObF1PuzYR99Grgc7AhignbomwYFvqkQoCMx4,27969
 lucid/models/imgclf/vgg.py,sha256=fWy78AAHJre3Msy4DK5nhQwThI-7frsdqRS-JYtFiXM,2457
-lucid/models/imgclf/vit.py,sha256=
+lucid/models/imgclf/vit.py,sha256=AUwsueQh9PY9d5org1PQjYzjSs9TVDOYElO9daO9Za8,3656
 lucid/models/imgclf/xception.py,sha256=Y7YKCzF_y4r864hLouW0eE7M-kxA59SiI3-iIFsXVhQ,3728
 lucid/models/imgclf/zfnet.py,sha256=brH5tHLVWTUfCqu-BwfFb0yZV9p5DmXN4O6cyP3U26U,1469
 lucid/models/imggen/__init__.py,sha256=J6MlEHqXxAYINbeQmyb85ev_IEOvQDTxTQjPgX6hdpY,59

@@ -76,11 +76,10 @@ lucid/models/objdet/yolo/yolo_v3.py,sha256=B5U42Npwfg8nSgU9E261zf0cbQS9RVYrX1ADD
 lucid/models/objdet/yolo/yolo_v4.py,sha256=RFbBumreXmy6s8IYZvUuhW0893ss8sx_8Vgi6KbBKWo,21467
 lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG2995-w,27
 lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
-lucid/nn/__init__.py,sha256=
+lucid/nn/__init__.py,sha256=nyy6px1CxfchWUh68xCiQSxD7Gk65vamhWK8ztRvH68,184
 lucid/nn/fused.py,sha256=75fcXuo6fHSO-JtjuKhowhHSDr4qc5871WR63sUzH0g,5492
 lucid/nn/module.py,sha256=_EWtGkAuWWCPZ5f3t5pJOOzpi14gQBpP7JW2S8o4_GE,26855
 lucid/nn/parameter.py,sha256=NQS65YKn2B59wZbZIoT1mpDsU_F08y3yLi7hmV1B6yo,1232
-lucid/nn/util.py,sha256=Yw1iBSPrGV_r_F51qpqLYdafNE_hyaA0DPWYP-rjaig,1699
 lucid/nn/_kernel/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
 lucid/nn/_kernel/activation.py,sha256=mfe48Aw3_Hv0hZEVC7DxDw19XK9XSLfdCOvo2JcZz_o,5662
 lucid/nn/_kernel/attention.py,sha256=1k0gboLObMNVow2v3TwliXC_2v8uKf2o8jHYFuyQqcg,3699

@@ -112,10 +111,13 @@ lucid/nn/modules/linear.py,sha256=87cuFWYct9JlmtVC3jGR-8eouxxzANaVA6cd7p9r2Ho,28
 lucid/nn/modules/loss.py,sha256=pjEMIruhtpTHhHFsNThS9LFz-aI_DAXLqMV8KRXydEg,3431
 lucid/nn/modules/norm.py,sha256=bYsKOg58kxzhMhbyvHrDDgVzN_p3D9HBTdYWpDtDeHQ,6842
 lucid/nn/modules/pool.py,sha256=ymVnS2NZjh08Tw0VeOfkB6AVrMeLmCKvgxkmEO3KUuw,5044
-lucid/nn/modules/rnn.py,sha256=
+lucid/nn/modules/rnn.py,sha256=L2rqFRcdr0U33YFeVvthDwDFIE98PrO-OjFiX9IzlIs,21098
 lucid/nn/modules/sparse.py,sha256=EpjiviED2nI55wUjh1twFwa4Lvlrzw0TR6lpCDGeSbo,1147
 lucid/nn/modules/transformer.py,sha256=z56emF_eX18pxRELjfmmsY-7Bn9h2yjIdxCaxs6YDwA,11246
 lucid/nn/modules/vision.py,sha256=8xYasT7TNj4NXwMwwJIw1nbV1paeWEFg_ZohXn9kZBg,1579
+lucid/nn/utils/__init__.py,sha256=ynHrPi9SPdRRXhGjghG42FRBcEiVN8Hb_04XHBZqy_o,46
+lucid/nn/utils/_grad.py,sha256=8EFN7TDHb09LHXK9dPjAdSLgGnL3r48Ct2rYztXKQxM,2335
+lucid/nn/utils/rnn.py,sha256=yJIktD-cbFvegzyDrif4aQFshpF64cCxAweCikrKm7s,6963
 lucid/optim/__init__.py,sha256=21EcCCPwrhPGP9TXvDje075_S2hPr0pHToygCaq8keI,201
 lucid/optim/_base.py,sha256=KxM5h5ONeO8hCpAzD2_vverFRKeymu2XC6AHN_L_v3g,4859
 lucid/optim/ada.py,sha256=-WQcC81oSYw3ffa59dPuNtDZfJ1KDrUw3zyKuPn5h5Y,6451

@@ -129,14 +131,13 @@ lucid/random/__init__.py,sha256=s8EAaKhEiTKT_vYjP4IFHx0xQVa1jqc_qIyvMauUu7M,2727
 lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232
 lucid/transforms/__init__.py,sha256=DGznMbqhXdU9FLDMKnJawScO4HCqu40Sf_j4vJGJrjc,90
 lucid/transforms/_base.py,sha256=v3elm7l0VoWvrT_qgoJiRzLH42tHoUcPIKNaPuxI_2E,1448
-lucid/transforms/image.py,sha256=
-lucid/visual/__init__.py,sha256=
-lucid/visual/
-lucid/visual/mermaid.py,sha256=87hFe4l9EYP6Cg2l2hP2INQiBHKkgVClH5nBWFY9ddY,26499
+lucid/transforms/image.py,sha256=Pn4AFQ5nQixOLmlpiSlxVd8tyALOvg24UgDueY-U8fc,3817
+lucid/visual/__init__.py,sha256=tRgyNHzKWA8cp-a_GV586Bs0yJUN5ZTmKgnUhscutHQ,23
+lucid/visual/mermaid.py,sha256=m0X0kkdLuCxEzKmXSy3zplUaa3Gov8RRonKyHiEvfHE,32738
 lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
 lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
-lucid_dl-2.
+lucid_dl-2.12.0.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
+lucid_dl-2.12.0.dist-info/METADATA,sha256=Y7doYNmgXQugwLzkYsJBv4Jzw1g9ZMsIxXYofaCmdAc,11679
+lucid_dl-2.12.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lucid_dl-2.12.0.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
+lucid_dl-2.12.0.dist-info/RECORD,,
lucid/visual/graph.py
DELETED
@@ -1,141 +0,0 @@
-from typing import Union
-from warnings import deprecated
-
-import networkx as nx
-import matplotlib.pyplot as plt
-
-import lucid.nn as nn
-from lucid._tensor import Tensor
-
-
-__all__ = ["draw_tensor_graph"]
-
-
-@deprecated("This feature will be re-written with Mermaid in future relases.")
-def draw_tensor_graph(
-    tensor: Tensor,
-    horizontal: bool = False,
-    title: Union[str, None] = None,
-    start_id: Union[int, None] = None,
-) -> plt.Figure:
-    G: nx.DiGraph = nx.DiGraph()
-    result_id: int = id(tensor)
-
-    visited: set[int] = set()
-    nodes_to_draw: list[Tensor] = []
-
-    def dfs(t: Tensor) -> None:
-        if id(t) in visited:
-            return
-        visited.add(id(t))
-        for p in t._prev:
-            dfs(p)
-        nodes_to_draw.append(t)
-
-    dfs(tensor)
-
-    for t in nodes_to_draw:
-        if not t.is_leaf and t._op is not None:
-            op_id: int = id(t._op)
-            op_label: str = type(t._op).__name__
-            G.add_node(op_id, label=op_label, shape="circle", color="lightgreen")
-            G.add_edge(op_id, id(t))
-            for inp in t._prev:
-                G.add_edge(id(inp), op_id)
-
-        shape_label: str = str(t.shape) if t.ndim > 0 else str(t.item())
-        if isinstance(t, nn.Parameter):
-            color: str = "plum"
-        else:
-            color = (
-                "lightcoral"
-                if id(t) == result_id
-                else "lightgray" if not t.requires_grad else "lightblue"
-            )
-        if start_id is not None and id(t) == start_id:
-            color = "gold"
-
-        G.add_node(id(t), label=shape_label, shape="rectangle", color=color)
-
-    def grid_layout(
-        G: nx.DiGraph, horizontal: bool = False
-    ) -> tuple[dict, tuple, float, int]:
-        levels: dict[int, int] = {}
-        for node in nx.topological_sort(G):
-            preds = list(G.predecessors(node))
-            levels[node] = 0 if not preds else max(levels[p] for p in preds) + 1
-
-        level_nodes: dict[int, list[int]] = {}
-        for node, level in levels.items():
-            level_nodes.setdefault(level, []).append(node)
-
-        def autoscale(
-            level_nodes: dict[int, list[int]],
-            horizontal: bool = False,
-            base_size: float = 0.5,
-            base_nodesize: int = 500,
-        ) -> tuple[tuple[float, float], float, int]:
-            num_levels: int = len(level_nodes)
-            max_width: int = max(len(nodes) for nodes in level_nodes.values())
-            node_count: int = sum(len(nodes) for nodes in level_nodes.values())
-
-            if horizontal:
-                fig_w: float = min(32, max(4.0, base_size * num_levels))
-                fig_h: float = min(32, max(4.0, base_size * max_width))
-            else:
-                fig_w = min(32, max(4.0, base_size * max_width))
-                fig_h = min(32, max(4.0, base_size * num_levels))
-
-            nodesize: float = (
-                base_nodesize
-                if node_count <= 100
-                else base_nodesize * (100 / node_count)
-            )
-            fontsize: int = max(5, min(8, int(80 / node_count)))
-            return (fig_w, fig_h), nodesize, fontsize
-
-        figsize, nodesize, fontsize = autoscale(level_nodes, horizontal)
-        pos: dict[int, tuple[float, float]] = {}
-        for level, nodes in level_nodes.items():
-            for i, node in enumerate(nodes):
-                pos[node] = (
-                    (level * 2.5, -i * 2.0) if horizontal else (i * 2.5, -level * 2.0)
-                )
-        return pos, figsize, nodesize, fontsize
-
-    labels: dict[int, str] = nx.get_node_attributes(G, "label")
-    colors: dict[int, str] = nx.get_node_attributes(G, "color")
-    shapes: dict[int, str] = nx.get_node_attributes(G, "shape")
-    pos, figsize, nodesize, fontsize = grid_layout(G, horizontal)
-
-    fig, ax = plt.subplots(figsize=figsize)
-
-    rect_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "rectangle"]
-    circ_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "circle"]
-    rect_colors: list[str] = [colors[n] for n in rect_nodes]
-
-    nx.draw_networkx_nodes(
-        G,
-        pos,
-        nodelist=rect_nodes,
-        node_color=rect_colors,
-        node_size=nodesize,
-        node_shape="s",
-        ax=ax,
-    )
-    nx.draw_networkx_nodes(
-        G,
-        pos,
-        nodelist=circ_nodes,
-        node_color="lightgreen",
-        node_size=nodesize,
-        node_shape="o",
-        ax=ax,
-    )
-    nx.draw_networkx_edges(G, pos, width=0.5, arrows=True, edge_color="gray", ax=ax)
-    nx.draw_networkx_labels(G, pos, labels=labels, font_size=fontsize, ax=ax)
-
-    ax.axis("off")
-    ax.set_title(title if title is not None else "")
-
-    return fig
/lucid/models/{util.py → utils.py}
File without changes
{lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/WHEEL
File without changes
{lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/licenses/LICENSE
File without changes
{lucid_dl-2.11.4.dist-info → lucid_dl-2.12.0.dist-info}/top_level.txt
File without changes