lucid-dl 2.11.4__py3-none-any.whl → 2.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lucid/datasets/__init__.py CHANGED
@@ -1,3 +1,2 @@
- # TODO: implement per-batch data loading and transformations
  from .mnist import *
  from .cifar import *
lucid/datasets/cifar.py CHANGED
@@ -4,6 +4,8 @@ import openml
  import math

  from typing import SupportsIndex, Tuple, ClassVar
+ from pathlib import Path
+ import re

  import lucid
  from lucid._tensor import Tensor
@@ -17,6 +19,42 @@ __all__ = ["CIFAR10", "CIFAR100"]
  class CIFAR10(DatasetBase):
      OPENML_ID: ClassVar[int] = 40927

+     def __init__(
+         self,
+         root: str | Path,
+         train: bool | None = True,
+         download: bool | None = False,
+         transform: lucid.nn.Module | None = None,
+         target_transform: lucid.nn.Module | None = None,
+         test_size: float = 0.2,
+         to_tensor: bool = True,
+         *,
+         cache: bool = True,
+         scale: float | None = None,
+         resize: tuple[int, int] | None = None,
+         normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
+         cache_preprocessed: bool = True,
+         preprocess_dtype: lucid.Numeric = lucid.Float16,
+         preprocess_chunk_size: int = 4096,
+     ) -> None:
+         self.cache = cache
+         self.scale = scale
+         self.resize = resize
+         self.normalize = normalize
+         self.cache_preprocessed = cache_preprocessed
+         self.preprocess_dtype = preprocess_dtype
+         self.preprocess_chunk_size = preprocess_chunk_size
+
+         super().__init__(
+             root=root,
+             train=train,
+             download=download,
+             transform=transform,
+             target_transform=target_transform,
+             test_size=test_size,
+             to_tensor=to_tensor,
+         )
+
      def _download(self) -> None:
          try:
              dataset = openml.datasets.get_dataset(self.OPENML_ID)
@@ -26,7 +64,36 @@ class CIFAR10(DatasetBase):
          except Exception as e:
              raise RuntimeError(f"Failed to download the CIFAR-10 dataset. Error: {e}")

-     def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+     def _cache_key(self) -> str:
+         parts: list[str] = []
+         if self.scale is not None:
+             parts.append(f"s{self.scale:g}")
+         if self.resize is not None:
+             parts.append(f"r{self.resize[0]}x{self.resize[1]}")
+         if self.normalize is not None:
+             mean, std = self.normalize
+             parts.append("m" + ",".join(f"{v:g}" for v in mean))
+             parts.append("v" + ",".join(f"{v:g}" for v in std))
+         if not parts:
+             return "raw"
+         key = "_".join(parts)
+         return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)
+
+     def _raw_cache_path(self) -> Path:
+         return self.root / "CIFAR10_uint8.npz"
+
+     def _proc_cache_path(self) -> Path:
+         dtype_name = str(self.preprocess_dtype)
+         return self.root / f"CIFAR10_{self._cache_key()}_{dtype_name}.npz"
+
+     def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
+         raw_path = self._raw_cache_path()
+         if self.cache and raw_path.exists():
+             with np.load(raw_path) as npz:
+                 images = npz["images"]
+                 labels = npz["labels"]
+             return images, labels
+
          csv_path = self.root / "CIFAR10.csv"
          if not csv_path.exists():
              raise RuntimeError(
@@ -36,9 +103,69 @@ class CIFAR10(DatasetBase):

          df = pd.read_csv(csv_path)
          labels = df["class"].values.astype(np.int32)
-         images = df.drop(columns=["class"]).values.astype(np.float32)
+         images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
          images = images.reshape(-1, 3, 32, 32)

+         if self.cache:
+             np.savez_compressed(raw_path, images=images, labels=labels)
+
+         return images, labels
+
+     def _maybe_preprocess_and_cache(
+         self, images_uint8: np.ndarray, labels_int32: np.ndarray
+     ) -> tuple[np.ndarray, np.ndarray]:
+         if self.resize is None and self.scale is None and self.normalize is None:
+             return images_uint8.astype(np.float32), labels_int32
+
+         proc_path = self._proc_cache_path()
+         if self.cache and self.cache_preprocessed and proc_path.exists():
+             with np.load(proc_path) as npz:
+                 images = npz["images"]
+                 labels = npz["labels"]
+             return images, labels
+
+         from lucid.transforms import Compose, Resize, Normalize
+
+         class _Scale(lucid.nn.Module):
+             def __init__(self, factor: float) -> None:
+                 super().__init__()
+                 self.factor = factor
+
+             def forward(self, x: Tensor) -> Tensor:
+                 return x * self.factor
+
+         transforms: list[lucid.nn.Module] = []
+         if self.resize is not None:
+             transforms.append(Resize(self.resize))
+         if self.scale is not None:
+             transforms.append(_Scale(self.scale))
+         if self.normalize is not None:
+             mean, std = self.normalize
+             transforms.append(Normalize(mean=mean, std=std))
+
+         transform = Compose(transforms)
+         n = images_uint8.shape[0]
+         out_h, out_w = self.resize if self.resize is not None else (32, 32)
+
+         out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
+         out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)
+
+         for start in range(0, n, self.preprocess_chunk_size):
+             end = min(start + self.preprocess_chunk_size, n)
+             chunk = images_uint8[start:end].astype(np.float32)
+             x = lucid.to_tensor(chunk, dtype=lucid.Float32)
+             x = transform(x)
+             out_images[start:end] = x.numpy().astype(out_dtype, copy=False)
+
+         if self.cache and self.cache_preprocessed:
+             np.savez_compressed(proc_path, images=out_images, labels=labels_int32)
+
+         return out_images, labels_int32
+
+     def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+         images, labels = self._ensure_raw_cache()
+         images, labels = self._maybe_preprocess_and_cache(images, labels)
+
          train_size = int(math.floor(len(images) * (1 - self.test_size)))
          if split == "train":
              images, labels = images[:train_size], labels[:train_size]
@@ -52,7 +179,7 @@ class CIFAR10(DatasetBase):
          return images, labels

      def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
-         image = self.data[index].reshape(-1, 3, 32, 32)
+         image = self.data[index]
          label = self.targets[index]

          if self.transform:
@@ -66,6 +193,42 @@
  class CIFAR100(DatasetBase):
      OPENML_ID: ClassVar[int] = 41983

+     def __init__(
+         self,
+         root: str | Path,
+         train: bool | None = True,
+         download: bool | None = False,
+         transform: lucid.nn.Module | None = None,
+         target_transform: lucid.nn.Module | None = None,
+         test_size: float = 0.2,
+         to_tensor: bool = True,
+         *,
+         cache: bool = True,
+         scale: float | None = None,
+         resize: tuple[int, int] | None = None,
+         normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
+         cache_preprocessed: bool = True,
+         preprocess_dtype: lucid.Numeric = lucid.Float16,
+         preprocess_chunk_size: int = 4096,
+     ) -> None:
+         self.cache = cache
+         self.scale = scale
+         self.resize = resize
+         self.normalize = normalize
+         self.cache_preprocessed = cache_preprocessed
+         self.preprocess_dtype = preprocess_dtype
+         self.preprocess_chunk_size = preprocess_chunk_size
+
+         super().__init__(
+             root=root,
+             train=train,
+             download=download,
+             transform=transform,
+             target_transform=target_transform,
+             test_size=test_size,
+             to_tensor=to_tensor,
+         )
+
      def _download(self) -> None:
          try:
              dataset = openml.datasets.get_dataset(self.OPENML_ID)
@@ -75,7 +238,37 @@ class CIFAR100(DatasetBase):
          except Exception as e:
              raise RuntimeError(f"Failed to download the CIFAR-100 dataset. Error: {e}")

-     def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+     def _cache_key(self) -> str:
+         parts: list[str] = []
+         if self.scale is not None:
+             parts.append(f"s{self.scale:g}")
+         if self.resize is not None:
+             parts.append(f"r{self.resize[0]}x{self.resize[1]}")
+         if self.normalize is not None:
+             mean, std = self.normalize
+             parts.append("m" + ",".join(f"{v:g}" for v in mean))
+             parts.append("v" + ",".join(f"{v:g}" for v in std))
+         if not parts:
+             return "raw"
+
+         key = "_".join(parts)
+         return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)
+
+     def _raw_cache_path(self) -> Path:
+         return self.root / "CIFAR100_uint8.npz"
+
+     def _proc_cache_path(self) -> Path:
+         dtype_name = str(self.preprocess_dtype)
+         return self.root / f"CIFAR100_{self._cache_key()}_{dtype_name}.npz"
+
+     def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
+         raw_path = self._raw_cache_path()
+         if self.cache and raw_path.exists():
+             with np.load(raw_path) as npz:
+                 images = npz["images"]
+                 labels = npz["labels"]
+             return images, labels
+
          csv_path = self.root / "CIFAR100.csv"
          if not csv_path.exists():
              raise RuntimeError(
@@ -85,9 +278,69 @@ class CIFAR100(DatasetBase):

          df = pd.read_csv(csv_path)
          labels = df["class"].values.astype(np.int32)
-         images = df.drop(columns=["class"]).values.astype(np.float32)
+         images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
          images = images.reshape(-1, 3, 32, 32)

+         if self.cache:
+             np.savez_compressed(raw_path, images=images, labels=labels)
+
+         return images, labels
+
+     def _maybe_preprocess_and_cache(
+         self, images_uint8: np.ndarray, labels_int32: np.ndarray
+     ) -> tuple[np.ndarray, np.ndarray]:
+         if self.resize is None and self.scale is None and self.normalize is None:
+             return images_uint8.astype(np.float32), labels_int32
+
+         proc_path = self._proc_cache_path()
+         if self.cache and self.cache_preprocessed and proc_path.exists():
+             with np.load(proc_path) as npz:
+                 images = npz["images"]
+                 labels = npz["labels"]
+             return images, labels
+
+         from lucid.transforms import Compose, Resize, Normalize
+
+         class _Scale(lucid.nn.Module):
+             def __init__(self, factor: float) -> None:
+                 super().__init__()
+                 self.factor = factor
+
+             def forward(self, x: Tensor) -> Tensor:
+                 return x * self.factor
+
+         transforms: list[lucid.nn.Module] = []
+         if self.resize is not None:
+             transforms.append(Resize(self.resize))
+         if self.scale is not None:
+             transforms.append(_Scale(self.scale))
+         if self.normalize is not None:
+             mean, std = self.normalize
+             transforms.append(Normalize(mean=mean, std=std))
+
+         transform = Compose(transforms)
+         n = images_uint8.shape[0]
+         out_h, out_w = self.resize if self.resize is not None else (32, 32)
+
+         out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
+         out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)
+
+         for start in range(0, n, self.preprocess_chunk_size):
+             end = min(start + self.preprocess_chunk_size, n)
+             chunk = images_uint8[start:end].astype(np.float32)
+             x = lucid.to_tensor(chunk, dtype=lucid.Float32)
+             x = transform(x)
+             out_images[start:end] = x.numpy().astype(out_dtype, copy=False)
+
+         if self.cache and self.cache_preprocessed:
+             np.savez_compressed(proc_path, images=out_images, labels=labels_int32)
+
+         return out_images, labels_int32
+
+     def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
+         images, labels = self._ensure_raw_cache()
+         images, labels = self._maybe_preprocess_and_cache(images, labels)
+
          train_size = int(math.floor(len(images) * (1 - self.test_size)))
          if split == "train":
              images, labels = images[:train_size], labels[:train_size]
@@ -101,7 +354,7 @@ class CIFAR100(DatasetBase):
          return images, labels

      def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
-         image = self.data[index].reshape(-1, 3, 32, 32)
+         image = self.data[index]
          label = self.targets[index]

          if self.transform:
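
Taken together, the cifar.py changes split loading into three stages: a one-time uint8 raw cache (`_ensure_raw_cache`), an optional preprocessed cache keyed by the transform settings (`_cache_key`), and a cheap `__getitem__` that serves already-shaped samples. A minimal usage sketch, assuming the constructor shown in the hunks above; the normalization statistics here are illustrative, not lucid defaults:

```python
import lucid
from lucid.datasets import CIFAR10

train_set = CIFAR10(
    root="./data",                    # raw cache lands in ./data/CIFAR10_uint8.npz
    train=True,
    download=True,
    scale=1.0 / 255.0,                # baked into the preprocessed cache
    normalize=((0.49, 0.48, 0.45), (0.25, 0.24, 0.26)),  # illustrative values
    preprocess_dtype=lucid.Float16,   # roughly halves the cached array size
)

# Samples come back already shaped (3, 32, 32); the old per-item reshape in
# __getitem__ is gone because _load_data now returns (N, 3, 32, 32) arrays.
image, label = train_set[0]
```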
lucid/models/imgclf/vit.py CHANGED
@@ -32,10 +32,12 @@ class ViT(nn.Module):
              in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
          )

-         self.cls_token = nn.Parameter(lucid.random.randn(1, 1, embedding_dim))
-         self.pos_emb = nn.Parameter(
-             lucid.random.randn(1, 1 + self.num_patches, embedding_dim)
-         )
+         self.cls_token = nn.Parameter(lucid.zeros(1, 1, embedding_dim))
+         self.pos_emb = nn.Parameter(lucid.zeros(1, 1 + self.num_patches, embedding_dim))
+
+         nn.init.normal(self.cls_token, std=0.02)
+         nn.init.normal(self.pos_emb, std=0.02)
+
          self.dropout = nn.Dropout(dropout_rate)

          encoder_layer = nn.TransformerEncoderLayer(
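
The switch from `lucid.random.randn` to zero-initialized parameters plus `nn.init.normal(..., std=0.02)` replaces unit-variance embeddings with the small-scale normal initialization conventional for ViT/BERT token and positional embeddings. A sketch of the pattern in isolation, using only calls that appear in the hunk:

```python
import lucid
import lucid.nn as nn

embedding_dim, num_patches = 768, 196

# Unit-variance randn (the old code) puts cls_token/pos_emb on the same scale
# as the activations; drawing from N(0, 0.02^2) keeps them near zero at start.
pos_emb = nn.Parameter(lucid.zeros(1, 1 + num_patches, embedding_dim))
nn.init.normal(pos_emb, std=0.02)  # in-place re-initialization
```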
lucid/nn/__init__.py CHANGED
@@ -4,4 +4,4 @@ from lucid.nn.modules import *
  from lucid.nn.fused import *

  import lucid.nn.init as init
- import lucid.nn.util as util
+ import lucid.nn.utils as utils
lucid/nn/modules/rnn.py CHANGED
@@ -5,6 +5,7 @@ import lucid.nn as nn
  import lucid.nn.functional as F

  from lucid._tensor import Tensor
+ from lucid.nn.utils.rnn import PackedSequence
  from lucid.types import Numeric, _DeviceType

  from .activation import Tanh, ReLU
@@ -351,21 +352,47 @@ class RNNBase(nn.Module):
          )

      def forward(
-         self, input_: Tensor, hx: Tensor | tuple[Tensor, Tensor] | None = None
-     ) -> tuple[Tensor, Tensor] | tuple[Tensor, tuple[Tensor, Tensor]]:
-         if input_.ndim != 3:
-             raise ValueError(
-                 f"RNNBase expected input with 3 dimensions, got {input_.ndim} dimensions"
-             )
+         self,
+         input_: Tensor | PackedSequence,
+         hx: Tensor | tuple[Tensor, Tensor] | None = None,
+     ) -> (
+         tuple[Tensor | PackedSequence, Tensor]
+         | tuple[Tensor | PackedSequence, tuple[Tensor, Tensor]]
+     ):
+         is_packed = isinstance(input_, PackedSequence)
+         if is_packed:
+             data = input_.data
+             batch_sizes = input_.batch_sizes
+             if data.ndim != 2:
+                 raise ValueError(
+                     "RNNBase expected packed data with 2 dimensions, "
+                     f"got {data.ndim} dimensions"
+                 )
+             if batch_sizes.ndim != 1 or batch_sizes.shape[0] == 0:
+                 raise ValueError(
+                     "PackedSequence batch_sizes must be a non-empty 1D tensor"
+                 )

-         if self.batch_first:
-             input_ = input_.swapaxes(0, 1)
+             batch_size = int(batch_sizes[0].item())
+             feat = data.shape[1]
+             if feat != self.input_size:
+                 raise ValueError(
+                     f"RNNBase expected input with feature size {self.input_size}, got {feat}"
+                 )
+         else:
+             if input_.ndim != 3:
+                 raise ValueError(
+                     f"RNNBase expected input with 3 dimensions, got {input_.ndim} dimensions"
+                 )

-         seq_len, batch_size, feat = input_.shape
-         if feat != self.input_size:
-             raise ValueError(
-                 f"RNNBase expected input with feature size {self.input_size}, got {feat}"
-             )
+             if self.batch_first:
+                 input_ = input_.swapaxes(0, 1)
+
+             seq_len, batch_size, feat = input_.shape
+             if feat != self.input_size:
+                 raise ValueError(
+                     f"RNNBase expected input with feature size {self.input_size}, got {feat}"
+                 )

          if self.is_lstm:
              if hx is None:
@@ -410,7 +437,7 @@ class RNNBase(nn.Module):
          if hx.shape[2] != self.hidden_size:
              raise ValueError("Incorrect hidden size in hx")

-         layer_input = input_
+         layer_input = data if is_packed else input_
          h_n_list: list[Tensor] = []
          c_n_list: list[Tensor] | None = [] if self.is_lstm else None

@@ -420,33 +447,111 @@
                  c_t = hx_c[layer_idx]
              else:
                  h_t = hx[layer_idx]
+
              outputs = []
+             if is_packed:
+                 final_h: list[Tensor] = []
+                 final_c: list[Tensor] | None = [] if self.is_lstm else None
+                 offset = 0
+
+                 prev_bs: int | None = None
+                 max_len = int(batch_sizes.shape[0])
+                 for t in range(max_len):
+                     bs = int(batch_sizes[t].item())
+                     if bs == 0:
+                         break
+
+                     if prev_bs is None:
+                         prev_bs = bs
+                     if bs > prev_bs:
+                         raise ValueError(
+                             "PackedSequence batch_sizes must be non-increasing"
+                         )
+
+                     if bs < prev_bs:
+                         final_h.append(h_t[bs:prev_bs])
+                         if self.is_lstm and final_c is not None:
+                             final_c.append(c_t[bs:prev_bs])
+
+                     h_t = h_t[:bs]
+                     if self.is_lstm:
+                         c_t = c_t[:bs]
+
+                     step_input = layer_input[offset : offset + bs]
+                     offset += bs
+
+                     if self.is_lstm:
+                         h_t, c_t = cell(step_input, (h_t, c_t))
+                     else:
+                         h_t = cell(step_input, h_t)
+
+                     outputs.append(h_t)
+                     prev_bs = bs
+
+                 final_h.append(h_t)
+                 if self.is_lstm and final_c is not None:
+                     final_c.append(c_t)
+
+                 h_n_list.append(
+                     lucid.concatenate(tuple(reversed(final_h)), axis=0).unsqueeze(
+                         axis=0
+                     )
+                 )
+                 if self.is_lstm and final_c is not None and c_n_list is not None:
+                     c_n_list.append(
+                         lucid.concatenate(tuple(reversed(final_c)), axis=0).unsqueeze(
+                             axis=0
+                         )
+                     )
+
+                 layer_output = (
+                     lucid.concatenate(tuple(outputs), axis=0)
+                     if outputs
+                     else layer_input[:0]
+                 )

-             for t in range(seq_len):
-                 if self.is_lstm:
-                     h_t, c_t = cell(layer_input[t], (h_t, c_t))
-                     outputs.append(h_t.unsqueeze(axis=0))
-                 else:
-                     h_t = cell(layer_input[t], h_t)
-                     outputs.append(h_t.unsqueeze(axis=0))
+             else:
+                 for t in range(seq_len):
+                     if self.is_lstm:
+                         h_t, c_t = cell(layer_input[t], (h_t, c_t))
+                         outputs.append(h_t.unsqueeze(axis=0))
+                     else:
+                         h_t = cell(layer_input[t], h_t)
+                         outputs.append(h_t.unsqueeze(axis=0))

-             layer_output = lucid.concatenate(tuple(outputs), axis=0)
+                 layer_output = lucid.concatenate(tuple(outputs), axis=0)

              if self.training and self.dropout > 0.0 and layer_idx < self.num_layers - 1:
                  layer_output = F.dropout(layer_output, p=self.dropout)

-             h_n_list.append(h_t.unsqueeze(axis=0))
-             if self.is_lstm and c_n_list is not None:
-                 c_n_list.append(c_t.unsqueeze(axis=0))
+             if not is_packed:
+                 h_n_list.append(h_t.unsqueeze(axis=0))
+                 if self.is_lstm and c_n_list is not None:
+                     c_n_list.append(c_t.unsqueeze(axis=0))
              layer_input = layer_output

-         output = layer_input
+         if is_packed:
+             output = PackedSequence(
+                 data=layer_input,
+                 batch_sizes=batch_sizes,
+                 sorted_indices=input_.sorted_indices,
+                 unsorted_indices=input_.unsorted_indices,
+             )
+         else:
+             output = layer_input
+
          h_n = lucid.concatenate(tuple(h_n_list), axis=0)
          if self.is_lstm and c_n_list is not None:
              c_n = lucid.concatenate(tuple(c_n_list), axis=0)

-         if self.batch_first:
-             output = output.swapaxes(0, 1)
+         if is_packed:
+             if input_.unsorted_indices is not None:
+                 h_n = h_n[:, input_.unsorted_indices]
+                 if self.is_lstm and c_n_list is not None:
+                     c_n = c_n[:, input_.unsorted_indices]
+         else:
+             if self.batch_first:
+                 output = output.swapaxes(0, 1)

          if self.is_lstm and c_n_list is not None:
              return output, (h_n, c_n)
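
With these changes `RNNBase.forward` consumes either a padded 3-D tensor or a `PackedSequence`, stepping through `batch_sizes` so padding positions are never fed to the cell and peeling finished sequences' hidden states off as the effective batch shrinks. A hedged usage sketch; `nn.LSTM` is assumed to be a thin wrapper over the `RNNBase` shown above:

```python
import lucid
import lucid.nn as nn
from lucid.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Two sequences of lengths 5 and 3, feature size 8 (descending lengths,
# as the enforce_sorted=True default requires).
batch = [lucid.random.randn(5, 8), lucid.random.randn(3, 8)]
padded = pad_sequence(batch)                       # (5, 2, 8), zero-padded
packed = pack_padded_sequence(padded, lengths=[5, 3])

lstm = nn.LSTM(input_size=8, hidden_size=16)       # assumed RNNBase subclass
packed_out, (h_n, c_n) = lstm(packed)              # output is a PackedSequence
output, lengths = pad_packed_sequence(packed_out)  # back to (5, 2, 16) + lengths
```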
lucid/nn/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from ._grad import *
+ from . import rnn as rnn
lucid/nn/util.py → lucid/nn/utils/_grad.py RENAMED
@@ -6,7 +6,7 @@ from lucid._tensor import Tensor
  from lucid.types import _Scalar


- __all__ = ["grad_norm", "clip_grad_norm", "clip_grad_value"]
+ __all__ = ["grad_norm", "get_total_norm", "clip_grad_norm", "clip_grad_value"]


  def _as_iter(parameters: Iterable[Tensor] | Tensor) -> list[Tensor]:
@@ -32,6 +32,25 @@ def grad_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tens
      return Tensor(total_norm, device=device)


+ def get_total_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tensor:
+     parameters = _as_iter(parameters)
+     if not parameters:
+         return Tensor(0.0)
+
+     device = parameters[0].device
+     grads: list[Tensor] = [p.grad for p in parameters if p.grad is not None]
+     if not grads:
+         return Tensor(0.0, device=device)
+
+     norm_pow_sum = 0.0
+     for g in grads:
+         grad_norm = lucid.linalg.norm(lucid.ravel(g), ord=norm_type).item()
+         norm_pow_sum += grad_norm**norm_type
+
+     total_norm = norm_pow_sum ** (1.0 / norm_type)
+     return Tensor(total_norm, device=device)
+
+
  def clip_grad_norm(
      parameters: Iterable[Tensor] | Tensor,
      max_norm: _Scalar,
@@ -39,7 +58,7 @@ def clip_grad_norm(
      eps: float = 1e-7,
  ) -> float:
      params: list[Tensor] = [p for p in _as_iter(parameters) if p.grad is not None]
-     total_norm = grad_norm(params, norm_type=norm_type)
+     total_norm = get_total_norm(params, norm_type=norm_type)

      clip_coef = float(max_norm) / (total_norm.item() + eps)
      if clip_coef < 1.0:
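
`get_total_norm` computes the global gradient norm as `(sum of per-gradient norms ** p) ** (1 / p)`, and `clip_grad_norm` now scales against that value. A minimal training-step sketch, assuming lucid's PyTorch-style `nn.Linear`, `Tensor.backward()`, and `Module.parameters()`:

```python
import lucid
import lucid.nn as nn
import lucid.nn.utils as utils

layer = nn.Linear(4, 2)                  # tiny model so gradients exist
loss = layer(lucid.random.randn(8, 4)).sum()
loss.backward()

total = utils.get_total_norm(layer.parameters(), norm_type=2)
print(total.item())                      # global L2 norm across all grads

# clip_grad_norm rescales every grad in place by max_norm / (total + eps)
# whenever the global norm exceeds max_norm.
utils.clip_grad_norm(layer.parameters(), 1.0)
```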
lucid/nn/utils/rnn.py ADDED
@@ -0,0 +1,237 @@
+ from dataclasses import dataclass
+ from typing import Iterable, Sequence
+
+ import lucid
+
+ from lucid._tensor import Tensor
+ from lucid.types import _Scalar
+
+
+ __all__ = [
+     "PackedSequence",
+     "pad_sequence",
+     "pack_padded_sequence",
+     "pad_packed_sequence",
+     "pack_sequence",
+     "unpack_sequence",
+ ]
+
+
+ @dataclass(frozen=True)
+ class PackedSequence:
+     data: Tensor
+     batch_sizes: Tensor
+     sorted_indices: Tensor | None = None
+     unsorted_indices: Tensor | None = None
+
+
+ def pad_sequence(
+     sequences: Iterable[Tensor], batch_first: bool = False, padding_value: _Scalar = 0
+ ) -> Tensor:
+     seq_list = list(sequences)
+     if not seq_list:
+         raise ValueError("pad_sequence expected a non-empty iterable of Tensors")
+
+     first = seq_list[0]
+     if not isinstance(first, Tensor):
+         raise TypeError("pad_sequence expects Tensor elements")
+
+     ndim = first.ndim
+     if ndim < 1:
+         raise ValueError("pad_sequence expects tensors with at least 1 dimension")
+
+     trailing_shape = first.shape[1:]
+     device = first.device
+     dtype = first.dtype
+
+     lengths: list[int] = []
+     for idx, seq in enumerate(seq_list):
+         if not isinstance(seq, Tensor):
+             raise TypeError("pad_sequence expects Tensor elements")
+         if seq.ndim != ndim:
+             raise ValueError(
+                 f"pad_sequence expects tensors with {ndim} dimensions, "
+                 f"got {seq.ndim} at index {idx}"
+             )
+         if seq.shape[1:] != trailing_shape:
+             raise ValueError(
+                 "pad_sequence expects all tensors to share the same trailing shape"
+             )
+         if seq.device != device:
+             raise ValueError("pad_sequence expects all tensors on the same device")
+         if seq.dtype != dtype:
+             raise ValueError("pad_sequence expects all tensors with the same dtype")
+         lengths.append(seq.shape[0])
+
+     max_len = max(lengths)
+     batch_size = len(seq_list)
+
+     if batch_first:
+         out_shape = (batch_size, max_len, *trailing_shape)
+     else:
+         out_shape = (max_len, batch_size, *trailing_shape)
+
+     output = lucid.full(out_shape, padding_value, dtype=dtype, device=device)
+     for i, seq in enumerate(seq_list):
+         length = lengths[i]
+         if length == 0:
+             continue
+         if batch_first:
+             output[i, :length] = seq
+         else:
+             output[:length, i] = seq
+
+     return output
+
+
+ def _as_lengths(lengths: Sequence[int] | Tensor, *, device: str) -> Tensor:
+     if isinstance(lengths, Tensor):
+         return lengths
+     return Tensor(list(lengths), device=device)
+
+
+ def _invert_permutation(indices: Tensor) -> Tensor:
+     return lucid.argsort(indices, axis=0)
+
+
+ def pack_padded_sequence(
+     input_: Tensor,
+     lengths: Sequence[int] | Tensor,
+     batch_first: bool = False,
+     enforce_sorted: bool = True,
+ ) -> PackedSequence:
+     if input_.ndim < 2:
+         raise ValueError(
+             f"pack_padded_sequence expected input with at least 2 dims, got {input_.ndim}"
+         )
+
+     if batch_first:
+         input_ = input_.swapaxes(0, 1)
+
+     seq_len, batch_size = input_.shape[0], input_.shape[1]
+     lengths_t = _as_lengths(lengths, device=input_.device)
+     if lengths_t.ndim != 1:
+         raise ValueError("lengths must be a 1D sequence or tensor")
+     if lengths_t.shape[0] != batch_size:
+         raise ValueError(
+             f"lengths size {lengths_t.shape[0]} does not match batch size {batch_size}"
+         )
+
+     sorted_indices = None
+     unsorted_indices = None
+
+     if enforce_sorted:
+         sorted_lengths = lengths_t
+     else:
+         sorted_indices = lucid.argsort(lengths_t, descending=True, axis=0)
+         unsorted_indices = _invert_permutation(sorted_indices)
+
+         lengths_t = lengths_t[sorted_indices]
+         input_ = input_[:, sorted_indices]
+         sorted_lengths = lengths_t
+
+     max_len = int(sorted_lengths[0].item())
+     if max_len > seq_len:
+         raise ValueError(
+             f"lengths has max {max_len} but input has sequence length {seq_len}"
+         )
+
+     batch_sizes: list[int] = []
+     chunks: list[Tensor] = []
+     for t in range(max_len):
+         bs = int((sorted_lengths > t).sum().item())
+         batch_sizes.append(bs)
+         if bs == 0:
+             break
+         chunks.append(input_[t, :bs])
+
+     if not chunks:
+         data = input_[:0]
+     else:
+         data = lucid.concatenate(tuple(chunks), axis=0)
+
+     return PackedSequence(
+         data=data,
+         batch_sizes=Tensor(batch_sizes, device=input_.device),
+         sorted_indices=sorted_indices,
+         unsorted_indices=unsorted_indices,
+     )
+
+
+ def pad_packed_sequence(
+     sequence: PackedSequence, batch_first: bool = False, padding_value: _Scalar = 0
+ ) -> tuple[Tensor, Tensor]:
+     data = sequence.data
+     batch_sizes = sequence.batch_sizes
+     if batch_sizes.ndim != 1:
+         raise ValueError("batch_sizes must be 1D")
+
+     max_len = int(batch_sizes.shape[0])
+     if max_len == 0:
+         raise ValueError("batch_sizes must be non-empty")
+
+     batch_size = int(batch_sizes[0].item())
+     trailing_shape = data.shape[1:]
+
+     if batch_first:
+         out_shape = (batch_size, max_len, *trailing_shape)
+     else:
+         out_shape = (max_len, batch_size, *trailing_shape)
+
+     output = lucid.full(out_shape, padding_value, dtype=data.dtype, device=data.device)
+
+     lengths = [0] * batch_size
+     offset = 0
+     for t in range(max_len):
+         bs = int(batch_sizes[t].item())
+         if bs == 0:
+             break
+
+         chunk = data[offset : offset + bs]
+         offset += bs
+         for i in range(bs):
+             lengths[i] += 1
+         if batch_first:
+             output[:bs, t] = chunk
+         else:
+             output[t, :bs] = chunk
+
+     lengths_t = Tensor(lengths, device=data.device)
+     if sequence.unsorted_indices is not None:
+         if batch_first:
+             output = output[sequence.unsorted_indices]
+         else:
+             output = output[:, sequence.unsorted_indices]
+         lengths_t = lengths_t[sequence.unsorted_indices]
+
+     return output, lengths_t
+
+
+ def pack_sequence(
+     sequences: Iterable[Tensor], enforce_sorted: bool = True
+ ) -> PackedSequence:
+     seq_list = list(sequences)
+     if not seq_list:
+         raise ValueError("pack_sequence expected a non-empty iterable of Tensors")
+
+     lengths = [seq.shape[0] for seq in seq_list]
+     padded = pad_sequence(seq_list, batch_first=False, padding_value=0.0)
+     return pack_padded_sequence(
+         padded, lengths, batch_first=False, enforce_sorted=enforce_sorted
+     )
+
+
+ def unpack_sequence(
+     sequence: PackedSequence, batch_first: bool = False
+ ) -> list[Tensor]:
+     padded, lengths = pad_packed_sequence(
+         sequence, batch_first=batch_first, padding_value=0.0
+     )
+     result: list[Tensor] = []
+     for i, length in enumerate(lengths):
+         l = int(length.item())
+         if batch_first:
+             result.append(padded[i, :l])
+         else:
+             result.append(padded[:l, i])
+     return result
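
The packing layout is easiest to see on a tiny batch: for lengths [3, 2], `batch_sizes` is [2, 2, 1] and `data` stores each time step's active rows contiguously. A worked sketch using the functions defined above; `lucid.tensor` as the constructor is assumed from elsewhere in the codebase:

```python
import lucid
from lucid.nn.utils.rnn import pack_sequence, unpack_sequence

a = lucid.tensor([[1.0], [2.0], [3.0]])  # length 3
b = lucid.tensor([[4.0], [5.0]])         # length 2

packed = pack_sequence([a, b])           # lengths already descending
print(packed.batch_sizes)                # [2, 2, 1]
print(packed.data.shape)                 # (5, 1): rows [1, 4, 2, 5, 3]

restored = unpack_sequence(packed)       # round-trips back to [a, b]
```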
lucid/transforms/image.py CHANGED
@@ -35,8 +35,8 @@ def add_batch_dim(func: Callable[..., Tensor]) -> Callable:
  class Normalize(nn.Module):
      def __init__(self, mean: tuple[float, ...], std: tuple[float, ...]) -> None:
          super().__init__()
-         self.mean = lucid.tensor(mean)
-         self.std = lucid.tensor(std)
+         self.mean = lucid.tensor(mean).reshape(1, len(mean), 1, 1)
+         self.std = lucid.tensor(std).reshape(1, len(std), 1, 1)

      @add_batch_dim
      def forward(self, img: Tensor) -> Tensor:
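
The `Normalize` fix is about broadcasting: a flat `(C,)` mean/std aligns with the trailing (width) axis of an `(N, C, H, W)` batch, so the old code either raised a shape error or normalized the wrong dimension whenever `C` happened to equal `W`. Reshaped to `(1, C, 1, 1)`, the statistics align with the channel axis. A sketch, assuming lucid's NumPy-style broadcasting:

```python
import lucid
from lucid.transforms import Normalize

# (1, C, 1, 1) statistics broadcast per channel across batch, height, width.
norm = Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))
batch = lucid.random.randn(8, 3, 32, 32)  # NCHW
out = norm(batch)                         # (x - mean) / std, per channel
```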
lucid/visual/__init__.py CHANGED
@@ -1,2 +1 @@
- from .graph import *
  from .mermaid import *
lucid/visual/mermaid.py CHANGED
@@ -9,7 +9,7 @@ from lucid._tensor import Tensor
  from lucid.types import _ShapeLike


- __all__ = ["build_mermaid_chart"]
+ __all__ = ["build_tensor_mermaid_chart", "build_module_mermaid_chart"]


  _NN_MODULES_PREFIX = "lucid.nn.modules."
@@ -255,7 +255,7 @@ def _collapse_repeated_children(
      return out


- def build_mermaid_chart(
+ def build_module_mermaid_chart(
      module: nn.Module,
      input_shape: _ShapeLike | list[_ShapeLike] | None = None,
      inputs: Iterable[Tensor] | Tensor | None = None,
@@ -751,6 +751,192 @@
      return text


+ def build_mermaid_chart(
+     module: nn.Module,
+     input_shape: _ShapeLike | list[_ShapeLike] | None = None,
+     inputs: Iterable[Tensor] | Tensor | None = None,
+     depth: int = 2,
+     direction: str = "LR",
+     include_io: bool = True,
+     show_params: bool = False,
+     return_lines: bool = False,
+     copy_to_clipboard: bool = False,
+     compact: bool = False,
+     use_class_defs: bool = False,
+     end_semicolons: bool = True,
+     edge_mode: Literal["dataflow", "execution"] = "execution",
+     collapse_repeats: bool = True,
+     repeat_min: int = 2,
+     color_by_subpackage: bool = True,
+     container_name_from_attr: bool = True,
+     edge_stroke_width: float = 2.0,
+     emphasize_model_title: bool = True,
+     model_title_font_px: int = 20,
+     show_shapes: bool = False,
+     hide_subpackages: Iterable[str] = (),
+     hide_module_names: Iterable[str] = (),
+     dash_multi_input_edges: bool = True,
+     subgraph_fill: str = "#000000",
+     subgraph_fill_opacity: float = 0.05,
+     subgraph_stroke: str = "#000000",
+     subgraph_stroke_opacity: float = 0.75,
+     force_text_color: str | None = None,
+     edge_curve: str = "natural",
+     node_spacing: int = 50,
+     rank_spacing: int = 50,
+     **forward_kwargs,
+ ) -> str | list[str]:
+     return build_module_mermaid_chart(
+         module,
+         input_shape=input_shape,
+         inputs=inputs,
+         depth=depth,
+         direction=direction,
+         include_io=include_io,
+         show_params=show_params,
+         return_lines=return_lines,
+         copy_to_clipboard=copy_to_clipboard,
+         compact=compact,
+         use_class_defs=use_class_defs,
+         end_semicolons=end_semicolons,
+         edge_mode=edge_mode,
+         collapse_repeats=collapse_repeats,
+         repeat_min=repeat_min,
+         color_by_subpackage=color_by_subpackage,
+         container_name_from_attr=container_name_from_attr,
+         edge_stroke_width=edge_stroke_width,
+         emphasize_model_title=emphasize_model_title,
+         model_title_font_px=model_title_font_px,
+         show_shapes=show_shapes,
+         hide_subpackages=hide_subpackages,
+         hide_module_names=hide_module_names,
+         dash_multi_input_edges=dash_multi_input_edges,
+         subgraph_fill=subgraph_fill,
+         subgraph_fill_opacity=subgraph_fill_opacity,
+         subgraph_stroke=subgraph_stroke,
+         subgraph_stroke_opacity=subgraph_stroke_opacity,
+         force_text_color=force_text_color,
+         edge_curve=edge_curve,
+         node_spacing=node_spacing,
+         rank_spacing=rank_spacing,
+         **forward_kwargs,
+     )
+
+
+ def build_tensor_mermaid_chart(
+     tensor: Tensor,
+     horizontal: bool = False,
+     title: str | None = None,
+     start_id: int | None = None,
+     end_semicolons: bool = True,
+     copy_to_clipboard: bool = False,
+     use_class_defs: bool = True,
+     op_fill: str = "lightgreen",
+     param_fill: str = "plum",
+     result_fill: str = "lightcoral",
+     leaf_fill: str = "lightgray",
+     grad_fill: str = "lightblue",
+     start_fill: str = "gold",
+     stroke_color: str = "#666",
+     stroke_width_px: int = 1,
+ ) -> str:
+     direction = "LR" if horizontal else "TD"
+     lines: list[str] = [f"flowchart {direction}"]
+     if title:
+         lines.append(f"%% {title}")
+
+     result_id: int = id(tensor)
+     visited: set[int] = set()
+     nodes_to_draw: list[Tensor] = []
+
+     def dfs(t: Tensor) -> None:
+         if id(t) in visited:
+             return
+         visited.add(id(t))
+         for p in t._prev:
+             dfs(p)
+         nodes_to_draw.append(t)
+
+     def tensor_node_id(t: Tensor) -> str:
+         return f"t_{id(t)}"
+
+     def op_node_id(op: object) -> str:
+         return f"op_{id(op)}"
+
+     def add_node(node_id: str, label: str, kind: str) -> None:
+         if node_id in defined_nodes:
+             return
+         defined_nodes.add(node_id)
+         if kind == "op":
+             lines.append(f'{node_id}(("{label}"))')
+         else:
+             lines.append(f'{node_id}["{label}"]')
+
+     dfs(tensor)
+
+     defined_nodes: set[str] = set()
+     edge_lines: list[str] = []
+     class_lines: list[str] = []
+
+     for t in nodes_to_draw:
+         t_id = tensor_node_id(t)
+
+         if not t.is_leaf and t._op is not None:
+             op_id = op_node_id(t._op)
+             op_label = type(t._op).__name__
+             add_node(op_id, op_label, "op")
+             edge_lines.append(f"{op_id} --> {t_id}")
+             class_lines.append(f"class {op_id} op")
+             for inp in t._prev:
+                 edge_lines.append(f"{tensor_node_id(inp)} --> {op_id}")
+
+         shape_label = str(t.shape) if t.ndim > 0 else str(t.item())
+         add_node(t_id, shape_label, "tensor")
+
+         if start_id is not None and id(t) == start_id:
+             class_lines.append(f"class {t_id} start")
+         elif isinstance(t, nn.Parameter):
+             class_lines.append(f"class {t_id} param")
+         elif id(t) == result_id:
+             class_lines.append(f"class {t_id} result")
+         elif not t.requires_grad:
+             class_lines.append(f"class {t_id} leaf")
+         else:
+             class_lines.append(f"class {t_id} grad")
+
+     lines.extend(edge_lines)
+     if use_class_defs:
+         lines.append(
+             f"classDef op fill:{op_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+         lines.append(
+             f"classDef param fill:{param_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+         lines.append(
+             f"classDef result fill:{result_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+         lines.append(
+             f"classDef leaf fill:{leaf_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+         lines.append(
+             f"classDef grad fill:{grad_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+         lines.append(
+             f"classDef start fill:{start_fill},stroke:{stroke_color},stroke-width:{stroke_width_px}px;"
+         )
+     lines.extend(class_lines)
+
+     if end_semicolons:
+         lines = [
+             f"{line};" if line and not line.endswith(";") else line for line in lines
+         ]
+
+     text = "\n".join(lines)
+     if copy_to_clipboard:
+         _copy_to_clipboard(text)
+     return text
+
+
  def _copy_to_clipboard(text: str) -> None:
      import os
      import shutil
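
`build_tensor_mermaid_chart` walks `_prev` links depth-first, emits an op node for every non-leaf tensor, and assigns style classes (param/result/leaf/grad/start), so the old matplotlib renderer is replaced by plain Mermaid text. A hedged usage sketch; toggling `requires_grad` on a leaf tensor is assumed to work as in PyTorch:

```python
import lucid
from lucid.visual import build_tensor_mermaid_chart

x = lucid.random.randn(2, 3)
x.requires_grad = True        # assumed settable on leaf tensors
y = (x * 2.0).sum()           # small graph: x -> Mul -> Sum -> y

chart = build_tensor_mermaid_chart(y, title="y = sum(2x)")
print(chart)                  # paste into any Mermaid renderer, e.g. mermaid.live
```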
lucid_dl-2.11.4.dist-info/METADATA → lucid_dl-2.12.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lucid-dl
- Version: 2.11.4
+ Version: 2.12.0
  Summary: Lumerico's Comprehensive Interface for Deep Learning
  Home-page: https://github.com/ChanLumerico/lucid
  Author: ChanLumerico
@@ -48,35 +48,9 @@ Whether you're a student, educator, or an advanced researcher seeking to demysti

  ### 🔥 What's New

- - Added additional `nn.Module` hooks for richer introspection during training:
+ - Added new visual tool: `lucid.visual.build_tensor_mermaid_chart`, which builds a Mermaid chart of a given tensor's computational graph

- ```python
- def register_forward_pre_hook(self, hook: Callable, *, with_kwargs: bool = False)
- ```
- ```python
- def register_forward_hook(self, hook: Callable, *, with_kwargs: bool = False)
- ```
- ```python
- def register_backward_hook(self, hook: Callable)
- ```
- ```python
- def register_full_backward_pre_hook(self, hook: Callable)
- ```
- ```python
- def register_full_backward_hook(self, hook: Callable)
- ```
- ```python
- def register_state_dict_pre_hook(self, hook: Callable)
- ```
- ```python
- def register_state_dict_hook(self, hook: Callable)
- ```
- ```python
- def register_load_state_dict_pre_hook(self, hook: Callable)
- ```
- ```python
- def register_load_state_dict_post_hook(self, hook: Callable)
- ```
+ - Added additional `nn.Module` hooks for richer introspection during training:

  ## 🔧 How to Install

lucid_dl-2.11.4.dist-info/RECORD → lucid_dl-2.12.0.dist-info/RECORD RENAMED
@@ -21,16 +21,16 @@ lucid/autograd/__init__.py,sha256=hDoK_B2chRFVhoxsT4vxRKangzBEMWqF8gj2hdoTenk,67
  lucid/data/__init__.py,sha256=qrDIQsnix5ZUEa0yrtomaaWbNJyJ3xEr2gdhRvg70_8,118
  lucid/data/_base.py,sha256=RM8xpBl8qFhm19n7eER_jOsRaxkL3rbOkwUvn6VetSE,5921
  lucid/data/_util.py,sha256=UsbliOrGmM0f1vqppoBPn3RSx53PIqcVx_yVOlHZB6A,1985
- lucid/datasets/__init__.py,sha256=vFlNOP38SSYG75_Yf0Jbyw3okSO7sY4wIjSxIzzIgWg,103
+ lucid/datasets/__init__.py,sha256=SY0bCxIUGrSBNfEs4KmOawirmkJ8i-Cyz2CZMBkMsmk,42
  lucid/datasets/_base.py,sha256=yeXPm3La3Beel7U_yPrxjXgGtjndJ3T6NYaQ2_H_Fak,1325
- lucid/datasets/cifar.py,sha256=Wf9r6PSgwFuuYsnZDEQOFNB-06AkqW4DrVsVPRyxEF0,3704
+ lucid/datasets/cifar.py,sha256=r8KX-j6svhx3Kk1hSNhexUsJws7Bj31PlLMVRS68dC4,13080
  lucid/datasets/mnist.py,sha256=PUXW2UwmlXJFVJiNkI9Jm58Qe4qWHGA63znkk-y9INM,8603
  lucid/einops/__init__.py,sha256=9Dlmfw6PsIU9b_a89Zre4yV2rztRHPCL4QpsUnXJwjM,802
  lucid/einops/_func.py,sha256=XXsX9lse_0turKoFnOTtLdY6hBUi0gq_8K81G7nr80I,21026
  lucid/linalg/__init__.py,sha256=N-LrlC3qSsOMt6Ad1-PP3Qc3QH6EWNf5P50GBvwb9aQ,1118
  lucid/linalg/_func.py,sha256=Iyeut5nHwQmO8N326kQUaTjgoKVoBaxt_gy_3NXXD60,16378
  lucid/models/__init__.py,sha256=wegfOBvwJTFFee8eVt90zJoLsbbEpdT5G2y-mpO5xcE,89
- lucid/models/util.py,sha256=2g8FLcMLRgVxgGEaYuwJyFxeXu-A_a4_MVr0K-TNh74,5195
+ lucid/models/utils.py,sha256=2g8FLcMLRgVxgGEaYuwJyFxeXu-A_a4_MVr0K-TNh74,5195
  lucid/models/imgclf/__init__.py,sha256=kQH-nNu8_TPJ7Av151WSpcY4GJ06gGAd6Ozs3m3KMcE,590
  lucid/models/imgclf/alex.py,sha256=fZsPdCjWUseCrxBwKj-i5fPSDYLgBpfm0SJe07YKRuE,1472
  lucid/models/imgclf/coatnet.py,sha256=HKjpy-lBKgz743EijT7jEeMxYjrZHzgU5fOrgtZfxYg,13720
@@ -55,7 +55,7 @@ lucid/models/imgclf/senet.py,sha256=I5o9eHWzquNyLqZM4thMtZtIBDYGczjARl1Isx6GyCk,
  lucid/models/imgclf/sknet.py,sha256=rENInsSB2yLXJ7A9kWZ-9lDFXcKaUOIpzV0359umPRI,4535
  lucid/models/imgclf/swin.py,sha256=lClJTX6ObF1PuzYR99Grgc7AhignbomwYFvqkQoCMx4,27969
  lucid/models/imgclf/vgg.py,sha256=fWy78AAHJre3Msy4DK5nhQwThI-7frsdqRS-JYtFiXM,2457
- lucid/models/imgclf/vit.py,sha256=NXzPIiyXxcE1-g25m36-_YwKnJZ0gl1-jf7G3V12jS0,3594
+ lucid/models/imgclf/vit.py,sha256=AUwsueQh9PY9d5org1PQjYzjSs9TVDOYElO9daO9Za8,3656
  lucid/models/imgclf/xception.py,sha256=Y7YKCzF_y4r864hLouW0eE7M-kxA59SiI3-iIFsXVhQ,3728
  lucid/models/imgclf/zfnet.py,sha256=brH5tHLVWTUfCqu-BwfFb0yZV9p5DmXN4O6cyP3U26U,1469
  lucid/models/imggen/__init__.py,sha256=J6MlEHqXxAYINbeQmyb85ev_IEOvQDTxTQjPgX6hdpY,59
@@ -76,11 +76,10 @@ lucid/models/objdet/yolo/yolo_v3.py,sha256=B5U42Npwfg8nSgU9E261zf0cbQS9RVYrX1ADD
  lucid/models/objdet/yolo/yolo_v4.py,sha256=RFbBumreXmy6s8IYZvUuhW0893ss8sx_8Vgi6KbBKWo,21467
  lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG2995-w,27
  lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
- lucid/nn/__init__.py,sha256=_hk6KltQIJuWXowXstMSu3TjiaTP8zMLNvGpjnA9Mpw,182
+ lucid/nn/__init__.py,sha256=nyy6px1CxfchWUh68xCiQSxD7Gk65vamhWK8ztRvH68,184
  lucid/nn/fused.py,sha256=75fcXuo6fHSO-JtjuKhowhHSDr4qc5871WR63sUzH0g,5492
  lucid/nn/module.py,sha256=_EWtGkAuWWCPZ5f3t5pJOOzpi14gQBpP7JW2S8o4_GE,26855
  lucid/nn/parameter.py,sha256=NQS65YKn2B59wZbZIoT1mpDsU_F08y3yLi7hmV1B6yo,1232
- lucid/nn/util.py,sha256=Yw1iBSPrGV_r_F51qpqLYdafNE_hyaA0DPWYP-rjaig,1699
  lucid/nn/_kernel/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
  lucid/nn/_kernel/activation.py,sha256=mfe48Aw3_Hv0hZEVC7DxDw19XK9XSLfdCOvo2JcZz_o,5662
  lucid/nn/_kernel/attention.py,sha256=1k0gboLObMNVow2v3TwliXC_2v8uKf2o8jHYFuyQqcg,3699
@@ -112,10 +111,13 @@ lucid/nn/modules/linear.py,sha256=87cuFWYct9JlmtVC3jGR-8eouxxzANaVA6cd7p9r2Ho,28
  lucid/nn/modules/loss.py,sha256=pjEMIruhtpTHhHFsNThS9LFz-aI_DAXLqMV8KRXydEg,3431
  lucid/nn/modules/norm.py,sha256=bYsKOg58kxzhMhbyvHrDDgVzN_p3D9HBTdYWpDtDeHQ,6842
  lucid/nn/modules/pool.py,sha256=ymVnS2NZjh08Tw0VeOfkB6AVrMeLmCKvgxkmEO3KUuw,5044
- lucid/nn/modules/rnn.py,sha256=lsvQZiEHm1wGbiNWKQngAle7MbqGaXSBM1LUieCaZIk,17233
+ lucid/nn/modules/rnn.py,sha256=L2rqFRcdr0U33YFeVvthDwDFIE98PrO-OjFiX9IzlIs,21098
  lucid/nn/modules/sparse.py,sha256=EpjiviED2nI55wUjh1twFwa4Lvlrzw0TR6lpCDGeSbo,1147
  lucid/nn/modules/transformer.py,sha256=z56emF_eX18pxRELjfmmsY-7Bn9h2yjIdxCaxs6YDwA,11246
  lucid/nn/modules/vision.py,sha256=8xYasT7TNj4NXwMwwJIw1nbV1paeWEFg_ZohXn9kZBg,1579
+ lucid/nn/utils/__init__.py,sha256=ynHrPi9SPdRRXhGjghG42FRBcEiVN8Hb_04XHBZqy_o,46
+ lucid/nn/utils/_grad.py,sha256=8EFN7TDHb09LHXK9dPjAdSLgGnL3r48Ct2rYztXKQxM,2335
+ lucid/nn/utils/rnn.py,sha256=yJIktD-cbFvegzyDrif4aQFshpF64cCxAweCikrKm7s,6963
  lucid/optim/__init__.py,sha256=21EcCCPwrhPGP9TXvDje075_S2hPr0pHToygCaq8keI,201
  lucid/optim/_base.py,sha256=KxM5h5ONeO8hCpAzD2_vverFRKeymu2XC6AHN_L_v3g,4859
  lucid/optim/ada.py,sha256=-WQcC81oSYw3ffa59dPuNtDZfJ1KDrUw3zyKuPn5h5Y,6451
@@ -129,14 +131,13 @@ lucid/random/__init__.py,sha256=s8EAaKhEiTKT_vYjP4IFHx0xQVa1jqc_qIyvMauUu7M,2727
  lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232
  lucid/transforms/__init__.py,sha256=DGznMbqhXdU9FLDMKnJawScO4HCqu40Sf_j4vJGJrjc,90
  lucid/transforms/_base.py,sha256=v3elm7l0VoWvrT_qgoJiRzLH42tHoUcPIKNaPuxI_2E,1448
- lucid/transforms/image.py,sha256=S9gZzMck4EQSmDQZ3ATi2fsUh4-hqFqeDjhMMJe8TdU,3762
- lucid/visual/__init__.py,sha256=NfHhHYNVv9mQQ4MST3-OAIkAcFyYrihJC4qUf88DySI,44
- lucid/visual/graph.py,sha256=ZSlrJI3dQwYjz8XbgAfNd8-8YuH9Ji7Mz1J6UsnHTaI,4711
- lucid/visual/mermaid.py,sha256=87hFe4l9EYP6Cg2l2hP2INQiBHKkgVClH5nBWFY9ddY,26499
+ lucid/transforms/image.py,sha256=Pn4AFQ5nQixOLmlpiSlxVd8tyALOvg24UgDueY-U8fc,3817
+ lucid/visual/__init__.py,sha256=tRgyNHzKWA8cp-a_GV586Bs0yJUN5ZTmKgnUhscutHQ,23
+ lucid/visual/mermaid.py,sha256=m0X0kkdLuCxEzKmXSy3zplUaa3Gov8RRonKyHiEvfHE,32738
  lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
  lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
- lucid_dl-2.11.4.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
- lucid_dl-2.11.4.dist-info/METADATA,sha256=F8r0MrpLAlRuT0IK0RFLQCurlxcj3gUYLK2-tyKhAOI,12273
- lucid_dl-2.11.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- lucid_dl-2.11.4.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
- lucid_dl-2.11.4.dist-info/RECORD,,
+ lucid_dl-2.12.0.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
+ lucid_dl-2.12.0.dist-info/METADATA,sha256=Y7doYNmgXQugwLzkYsJBv4Jzw1g9ZMsIxXYofaCmdAc,11679
+ lucid_dl-2.12.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ lucid_dl-2.12.0.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
+ lucid_dl-2.12.0.dist-info/RECORD,,
lucid/visual/graph.py DELETED
@@ -1,141 +0,0 @@
- from typing import Union
- from warnings import deprecated
-
- import networkx as nx
- import matplotlib.pyplot as plt
-
- import lucid.nn as nn
- from lucid._tensor import Tensor
-
-
- __all__ = ["draw_tensor_graph"]
-
-
- @deprecated("This feature will be re-written with Mermaid in future releases.")
- def draw_tensor_graph(
-     tensor: Tensor,
-     horizontal: bool = False,
-     title: Union[str, None] = None,
-     start_id: Union[int, None] = None,
- ) -> plt.Figure:
-     G: nx.DiGraph = nx.DiGraph()
-     result_id: int = id(tensor)
-
-     visited: set[int] = set()
-     nodes_to_draw: list[Tensor] = []
-
-     def dfs(t: Tensor) -> None:
-         if id(t) in visited:
-             return
-         visited.add(id(t))
-         for p in t._prev:
-             dfs(p)
-         nodes_to_draw.append(t)
-
-     dfs(tensor)
-
-     for t in nodes_to_draw:
-         if not t.is_leaf and t._op is not None:
-             op_id: int = id(t._op)
-             op_label: str = type(t._op).__name__
-             G.add_node(op_id, label=op_label, shape="circle", color="lightgreen")
-             G.add_edge(op_id, id(t))
-             for inp in t._prev:
-                 G.add_edge(id(inp), op_id)
-
-         shape_label: str = str(t.shape) if t.ndim > 0 else str(t.item())
-         if isinstance(t, nn.Parameter):
-             color: str = "plum"
-         else:
-             color = (
-                 "lightcoral"
-                 if id(t) == result_id
-                 else "lightgray" if not t.requires_grad else "lightblue"
-             )
-         if start_id is not None and id(t) == start_id:
-             color = "gold"
-
-         G.add_node(id(t), label=shape_label, shape="rectangle", color=color)
-
-     def grid_layout(
-         G: nx.DiGraph, horizontal: bool = False
-     ) -> tuple[dict, tuple, float, int]:
-         levels: dict[int, int] = {}
-         for node in nx.topological_sort(G):
-             preds = list(G.predecessors(node))
-             levels[node] = 0 if not preds else max(levels[p] for p in preds) + 1
-
-         level_nodes: dict[int, list[int]] = {}
-         for node, level in levels.items():
-             level_nodes.setdefault(level, []).append(node)
-
-         def autoscale(
-             level_nodes: dict[int, list[int]],
-             horizontal: bool = False,
-             base_size: float = 0.5,
-             base_nodesize: int = 500,
-         ) -> tuple[tuple[float, float], float, int]:
-             num_levels: int = len(level_nodes)
-             max_width: int = max(len(nodes) for nodes in level_nodes.values())
-             node_count: int = sum(len(nodes) for nodes in level_nodes.values())
-
-             if horizontal:
-                 fig_w: float = min(32, max(4.0, base_size * num_levels))
-                 fig_h: float = min(32, max(4.0, base_size * max_width))
-             else:
-                 fig_w = min(32, max(4.0, base_size * max_width))
-                 fig_h = min(32, max(4.0, base_size * num_levels))
-
-             nodesize: float = (
-                 base_nodesize
-                 if node_count <= 100
-                 else base_nodesize * (100 / node_count)
-             )
-             fontsize: int = max(5, min(8, int(80 / node_count)))
-             return (fig_w, fig_h), nodesize, fontsize
-
-         figsize, nodesize, fontsize = autoscale(level_nodes, horizontal)
-         pos: dict[int, tuple[float, float]] = {}
-         for level, nodes in level_nodes.items():
-             for i, node in enumerate(nodes):
-                 pos[node] = (
-                     (level * 2.5, -i * 2.0) if horizontal else (i * 2.5, -level * 2.0)
-                 )
-         return pos, figsize, nodesize, fontsize
-
-     labels: dict[int, str] = nx.get_node_attributes(G, "label")
-     colors: dict[int, str] = nx.get_node_attributes(G, "color")
-     shapes: dict[int, str] = nx.get_node_attributes(G, "shape")
-     pos, figsize, nodesize, fontsize = grid_layout(G, horizontal)
-
-     fig, ax = plt.subplots(figsize=figsize)
-
-     rect_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "rectangle"]
-     circ_nodes: list[int] = [n for n in G.nodes() if shapes.get(n) == "circle"]
-     rect_colors: list[str] = [colors[n] for n in rect_nodes]
-
-     nx.draw_networkx_nodes(
-         G,
-         pos,
-         nodelist=rect_nodes,
-         node_color=rect_colors,
-         node_size=nodesize,
-         node_shape="s",
-         ax=ax,
-     )
-     nx.draw_networkx_nodes(
-         G,
-         pos,
-         nodelist=circ_nodes,
-         node_color="lightgreen",
-         node_size=nodesize,
-         node_shape="o",
-         ax=ax,
-     )
-     nx.draw_networkx_edges(G, pos, width=0.5, arrows=True, edge_color="gray", ax=ax)
-     nx.draw_networkx_labels(G, pos, labels=labels, font_size=fontsize, ax=ax)
-
-     ax.axis("off")
-     ax.set_title(title if title is not None else "")
-
-     return fig