lucid-dl 2.11.5__tar.gz → 2.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/PKG-INFO +1 -1
- lucid_dl-2.12.0/lucid/datasets/__init__.py +2 -0
- lucid_dl-2.12.0/lucid/datasets/cifar.py +365 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/vit.py +6 -4
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/__init__.py +1 -1
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/rnn.py +133 -28
- lucid_dl-2.12.0/lucid/nn/utils/__init__.py +2 -0
- lucid_dl-2.11.5/lucid/nn/util.py → lucid_dl-2.12.0/lucid/nn/utils/_grad.py +21 -2
- lucid_dl-2.12.0/lucid/nn/utils/rnn.py +237 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/transforms/image.py +2 -2
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid_dl.egg-info/PKG-INFO +1 -1
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid_dl.egg-info/SOURCES.txt +4 -2
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/setup.py +1 -1
- lucid_dl-2.11.5/lucid/datasets/__init__.py +0 -3
- lucid_dl-2.11.5/lucid/datasets/cifar.py +0 -112
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/LICENSE +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/README.md +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_backend/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_backend/core.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_backend/metal.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_func/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_func/bfunc.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_func/gfunc.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_func/ufunc.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_fusion/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_fusion/base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_fusion/func.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_tensor/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_tensor/base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_tensor/tensor.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_util/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/_util/func.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/autograd/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/data/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/data/_base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/data/_util.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/datasets/_base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/datasets/mnist.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/einops/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/einops/_func.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/error.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/linalg/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/linalg/_func.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/alex.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/coatnet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/convnext.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/crossvit.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/cspnet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/cvt.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/dense.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/efficient.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/efficientformer.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/inception.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/inception_next.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/inception_res.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/lenet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/maxvit.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/mobile.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/pvt.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/resnest.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/resnet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/resnext.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/senet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/sknet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/swin.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/vgg.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/xception.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imgclf/zfnet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imggen/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imggen/ddpm.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imggen/ncsn.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/imggen/vae.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/detr.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/efficientdet.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/fast_rcnn.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/faster_rcnn.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/rcnn.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/util.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/yolo/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/yolo/yolo_v1.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/yolo/yolo_v2.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/yolo/yolo_v3.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/objdet/yolo/yolo_v4.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/seq2seq/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/models/seq2seq/transformer.py +0 -0
- /lucid_dl-2.11.5/lucid/models/util.py → /lucid_dl-2.12.0/lucid/models/utils.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/activation.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/attention.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/conv.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/embedding.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/loss.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/norm.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/_kernel/pool.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_activation.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_attention.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_conv.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_drop.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_linear.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_loss.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_norm.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_pool.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_spatial.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/functional/_util.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/fused.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/init/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/init/_dist.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/module.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/activation.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/attention.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/conv.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/drop.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/einops.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/linear.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/loss.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/norm.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/pool.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/sparse.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/transformer.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/modules/vision.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/nn/parameter.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/_base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/ada.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/adam.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/lr_scheduler/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/lr_scheduler/_base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/lr_scheduler/_schedulers.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/prop.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/optim/sgd.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/port.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/random/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/random/_func.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/transforms/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/transforms/_base.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/types.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/visual/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/visual/mermaid.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/weights/__init__.py +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid/weights/__init__.pyi +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid_dl.egg-info/dependency_links.txt +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid_dl.egg-info/requires.txt +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/lucid_dl.egg-info/top_level.txt +0 -0
- {lucid_dl-2.11.5 → lucid_dl-2.12.0}/setup.cfg +0 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import openml
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from typing import SupportsIndex, Tuple, ClassVar
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import re
|
|
9
|
+
|
|
10
|
+
import lucid
|
|
11
|
+
from lucid._tensor import Tensor
|
|
12
|
+
|
|
13
|
+
from ._base import DatasetBase
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
__all__ = ["CIFAR10", "CIFAR100"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CIFAR10(DatasetBase):
    """CIFAR-10 image-classification dataset fetched from OpenML.

    Raw data is downloaded once as a CSV, cached as a compressed ``uint8``
    ``.npz`` archive of shape ``(N, 3, 32, 32)``, and optionally preprocessed
    (resize / scale / normalize) in chunks with the result cached separately
    per preprocessing configuration.
    """

    # OpenML dataset id for CIFAR-10.
    OPENML_ID: ClassVar[int] = 40927

    def __init__(
        self,
        root: str | Path,
        train: bool | None = True,
        download: bool | None = False,
        transform: lucid.nn.Module | None = None,
        target_transform: lucid.nn.Module | None = None,
        test_size: float = 0.2,
        to_tensor: bool = True,
        *,
        cache: bool = True,
        scale: float | None = None,
        resize: tuple[int, int] | None = None,
        normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
        cache_preprocessed: bool = True,
        preprocess_dtype: lucid.Numeric = lucid.Float16,
        preprocess_chunk_size: int = 4096,
    ) -> None:
        """Initialize the dataset.

        Args:
            root: Directory where downloaded/cached files are stored.
            train: Select the train split when True, test split otherwise.
            download: Fetch the dataset from OpenML when True.
            transform: Optional per-sample image transform.
            target_transform: Optional per-sample label transform.
            test_size: Fraction of samples reserved for the test split.
            to_tensor: Convert loaded arrays to lucid Tensors when True.
            cache: Persist raw (and preprocessed) arrays as ``.npz`` files.
            scale: Optional multiplicative factor applied to pixel values.
            resize: Optional ``(height, width)`` output size.
            normalize: Optional ``(mean, std)`` per-channel statistics.
            cache_preprocessed: Also cache the preprocessed arrays.
            preprocess_dtype: Storage dtype for preprocessed images.
            preprocess_chunk_size: Number of images preprocessed per chunk.
        """
        # Preprocessing options must be assigned before super().__init__(),
        # because the base class may call _download()/_load_data() during
        # construction and those methods read these attributes.
        self.cache = cache
        self.scale = scale
        self.resize = resize
        self.normalize = normalize
        self.cache_preprocessed = cache_preprocessed
        self.preprocess_dtype = preprocess_dtype
        self.preprocess_chunk_size = preprocess_chunk_size

        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
            test_size=test_size,
            to_tensor=to_tensor,
        )

    def _download(self) -> None:
        """Download the dataset from OpenML and store it as a CSV in root."""
        try:
            dataset = openml.datasets.get_dataset(self.OPENML_ID)
            df, _, _, _ = dataset.get_data(dataset_format="dataframe")
            df.to_csv(self.root / "CIFAR10.csv", index=False)

        except Exception as e:
            # Chain the original exception so the underlying network/IO
            # failure is preserved in the traceback.
            raise RuntimeError(
                f"Failed to download the CIFAR-10 dataset. Error: {e}"
            ) from e

    def _cache_key(self) -> str:
        """Build a filesystem-safe key encoding the preprocessing config."""
        parts: list[str] = []
        if self.scale is not None:
            parts.append(f"s{self.scale:g}")
        if self.resize is not None:
            parts.append(f"r{self.resize[0]}x{self.resize[1]}")
        if self.normalize is not None:
            mean, std = self.normalize
            parts.append("m" + ",".join(f"{v:g}" for v in mean))
            parts.append("v" + ",".join(f"{v:g}" for v in std))
        if not parts:
            return "raw"
        key = "_".join(parts)
        # Collapse any character not safe for filenames into underscores.
        return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)

    def _raw_cache_path(self) -> Path:
        """Path of the cached raw uint8 image archive."""
        return self.root / "CIFAR10_uint8.npz"

    def _proc_cache_path(self) -> Path:
        """Path of the cached preprocessed archive for the current config."""
        dtype_name = str(self.preprocess_dtype)
        return self.root / f"CIFAR10_{self._cache_key()}_{dtype_name}.npz"

    def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
        """Return raw ``(images, labels)`` arrays, building the cache if needed.

        Images are ``uint8`` of shape ``(N, 3, 32, 32)``; labels are ``int32``.

        Raises:
            RuntimeError: If neither the cache nor the downloaded CSV exists.
        """
        raw_path = self._raw_cache_path()
        if self.cache and raw_path.exists():
            with np.load(raw_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
                return images, labels

        csv_path = self.root / "CIFAR10.csv"
        if not csv_path.exists():
            raise RuntimeError(
                f"CIFAR-10 dataset CSV file not found at {csv_path}. "
                + "Use `download=True`."
            )

        df = pd.read_csv(csv_path)
        labels = df["class"].values.astype(np.int32)
        images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
        # NOTE(review): assumes the CSV pixel columns are laid out
        # channel-first (3, 32, 32) per row — confirm against the OpenML
        # dataset description.
        images = images.reshape(-1, 3, 32, 32)

        if self.cache:
            np.savez_compressed(raw_path, images=images, labels=labels)

        return images, labels

    def _maybe_preprocess_and_cache(
        self, images_uint8: np.ndarray, labels_int32: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply resize/scale/normalize (if configured) with chunked caching.

        Returns float arrays ready for tensor conversion; when no
        preprocessing is configured, images are simply cast to float32.
        """
        if self.resize is None and self.scale is None and self.normalize is None:
            return images_uint8.astype(np.float32), labels_int32

        proc_path = self._proc_cache_path()
        if self.cache and self.cache_preprocessed and proc_path.exists():
            with np.load(proc_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
                return images, labels

        # Imported locally — presumably to avoid a circular import between
        # lucid.datasets and lucid.transforms; verify before hoisting.
        from lucid.transforms import Compose, Resize, Normalize

        class _Scale(lucid.nn.Module):
            """Multiply inputs by a constant factor (pixel rescaling)."""

            def __init__(self, factor: float) -> None:
                super().__init__()
                self.factor = factor

            def forward(self, x: Tensor) -> Tensor:
                return x * self.factor

        transforms: list[lucid.nn.Module] = []
        if self.resize is not None:
            transforms.append(Resize(self.resize))
        if self.scale is not None:
            transforms.append(_Scale(self.scale))
        if self.normalize is not None:
            mean, std = self.normalize
            transforms.append(Normalize(mean=mean, std=std))

        transform = Compose(transforms)
        n = images_uint8.shape[0]
        out_h, out_w = self.resize if self.resize is not None else (32, 32)

        out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
        out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)

        # Preprocess in chunks to bound peak memory usage.
        for start in range(0, n, self.preprocess_chunk_size):
            end = min(start + self.preprocess_chunk_size, n)
            chunk = images_uint8[start:end].astype(np.float32)
            x = lucid.to_tensor(chunk, dtype=lucid.Float32)
            x = transform(x)
            out_images[start:end] = x.numpy().astype(out_dtype, copy=False)

        if self.cache and self.cache_preprocessed:
            np.savez_compressed(proc_path, images=out_images, labels=labels_int32)

        return out_images, labels_int32

    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
        """Load the requested split ("train" or anything else = test).

        The split is a deterministic head/tail cut at
        ``floor(N * (1 - test_size))`` — the data is not shuffled here.
        """
        images, labels = self._ensure_raw_cache()
        images, labels = self._maybe_preprocess_and_cache(images, labels)

        train_size = int(math.floor(len(images) * (1 - self.test_size)))
        if split == "train":
            images, labels = images[:train_size], labels[:train_size]
        else:
            images, labels = images[train_size:], labels[train_size:]

        if self.to_tensor:
            images = lucid.to_tensor(images, dtype=lucid.Float32)
            labels = lucid.to_tensor(labels, dtype=lucid.Int32)

        return images, labels

    def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
        """Return the ``(image, label)`` pair at ``index`` with transforms applied."""
        image = self.data[index]
        label = self.targets[index]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class CIFAR100(DatasetBase):
    """CIFAR-100 image-classification dataset fetched from OpenML.

    Raw data is downloaded once as a CSV, cached as a compressed ``uint8``
    ``.npz`` archive of shape ``(N, 3, 32, 32)``, and optionally preprocessed
    (resize / scale / normalize) in chunks with the result cached separately
    per preprocessing configuration.
    """

    # OpenML dataset id for CIFAR-100.
    OPENML_ID: ClassVar[int] = 41983

    def __init__(
        self,
        root: str | Path,
        train: bool | None = True,
        download: bool | None = False,
        transform: lucid.nn.Module | None = None,
        target_transform: lucid.nn.Module | None = None,
        test_size: float = 0.2,
        to_tensor: bool = True,
        *,
        cache: bool = True,
        scale: float | None = None,
        resize: tuple[int, int] | None = None,
        normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
        cache_preprocessed: bool = True,
        preprocess_dtype: lucid.Numeric = lucid.Float16,
        preprocess_chunk_size: int = 4096,
    ) -> None:
        """Initialize the dataset.

        Args:
            root: Directory where downloaded/cached files are stored.
            train: Select the train split when True, test split otherwise.
            download: Fetch the dataset from OpenML when True.
            transform: Optional per-sample image transform.
            target_transform: Optional per-sample label transform.
            test_size: Fraction of samples reserved for the test split.
            to_tensor: Convert loaded arrays to lucid Tensors when True.
            cache: Persist raw (and preprocessed) arrays as ``.npz`` files.
            scale: Optional multiplicative factor applied to pixel values.
            resize: Optional ``(height, width)`` output size.
            normalize: Optional ``(mean, std)`` per-channel statistics.
            cache_preprocessed: Also cache the preprocessed arrays.
            preprocess_dtype: Storage dtype for preprocessed images.
            preprocess_chunk_size: Number of images preprocessed per chunk.
        """
        # Preprocessing options must be assigned before super().__init__(),
        # because the base class may call _download()/_load_data() during
        # construction and those methods read these attributes.
        self.cache = cache
        self.scale = scale
        self.resize = resize
        self.normalize = normalize
        self.cache_preprocessed = cache_preprocessed
        self.preprocess_dtype = preprocess_dtype
        self.preprocess_chunk_size = preprocess_chunk_size

        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
            test_size=test_size,
            to_tensor=to_tensor,
        )

    def _download(self) -> None:
        """Download the dataset from OpenML and store it as a CSV in root."""
        try:
            dataset = openml.datasets.get_dataset(self.OPENML_ID)
            df, _, _, _ = dataset.get_data(dataset_format="dataframe")
            df.to_csv(self.root / "CIFAR100.csv", index=False)

        except Exception as e:
            # Chain the original exception so the underlying network/IO
            # failure is preserved in the traceback.
            raise RuntimeError(
                f"Failed to download the CIFAR-100 dataset. Error: {e}"
            ) from e

    def _cache_key(self) -> str:
        """Build a filesystem-safe key encoding the preprocessing config."""
        parts: list[str] = []
        if self.scale is not None:
            parts.append(f"s{self.scale:g}")
        if self.resize is not None:
            parts.append(f"r{self.resize[0]}x{self.resize[1]}")
        if self.normalize is not None:
            mean, std = self.normalize
            parts.append("m" + ",".join(f"{v:g}" for v in mean))
            parts.append("v" + ",".join(f"{v:g}" for v in std))
        if not parts:
            return "raw"

        key = "_".join(parts)
        # Collapse any character not safe for filenames into underscores.
        return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)

    def _raw_cache_path(self) -> Path:
        """Path of the cached raw uint8 image archive."""
        return self.root / "CIFAR100_uint8.npz"

    def _proc_cache_path(self) -> Path:
        """Path of the cached preprocessed archive for the current config."""
        dtype_name = str(self.preprocess_dtype)
        return self.root / f"CIFAR100_{self._cache_key()}_{dtype_name}.npz"

    def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
        """Return raw ``(images, labels)`` arrays, building the cache if needed.

        Images are ``uint8`` of shape ``(N, 3, 32, 32)``; labels are ``int32``.

        Raises:
            RuntimeError: If neither the cache nor the downloaded CSV exists.
        """
        raw_path = self._raw_cache_path()
        if self.cache and raw_path.exists():
            with np.load(raw_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
                return images, labels

        csv_path = self.root / "CIFAR100.csv"
        if not csv_path.exists():
            raise RuntimeError(
                f"CIFAR-100 dataset CSV file not found at {csv_path}. "
                + "Use `download=True`."
            )

        df = pd.read_csv(csv_path)
        labels = df["class"].values.astype(np.int32)
        images = df.drop(columns=["class"]).values.astype(np.uint8, copy=False)
        # NOTE(review): assumes the CSV pixel columns are laid out
        # channel-first (3, 32, 32) per row — confirm against the OpenML
        # dataset description.
        images = images.reshape(-1, 3, 32, 32)

        if self.cache:
            np.savez_compressed(raw_path, images=images, labels=labels)

        return images, labels

    def _maybe_preprocess_and_cache(
        self, images_uint8: np.ndarray, labels_int32: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply resize/scale/normalize (if configured) with chunked caching.

        Returns float arrays ready for tensor conversion; when no
        preprocessing is configured, images are simply cast to float32.
        """
        if self.resize is None and self.scale is None and self.normalize is None:
            return images_uint8.astype(np.float32), labels_int32

        proc_path = self._proc_cache_path()
        if self.cache and self.cache_preprocessed and proc_path.exists():
            with np.load(proc_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
                return images, labels

        # Imported locally — presumably to avoid a circular import between
        # lucid.datasets and lucid.transforms; verify before hoisting.
        from lucid.transforms import Compose, Resize, Normalize

        class _Scale(lucid.nn.Module):
            """Multiply inputs by a constant factor (pixel rescaling)."""

            def __init__(self, factor: float) -> None:
                super().__init__()
                self.factor = factor

            def forward(self, x: Tensor) -> Tensor:
                return x * self.factor

        transforms: list[lucid.nn.Module] = []
        if self.resize is not None:
            transforms.append(Resize(self.resize))
        if self.scale is not None:
            transforms.append(_Scale(self.scale))
        if self.normalize is not None:
            mean, std = self.normalize
            transforms.append(Normalize(mean=mean, std=std))

        transform = Compose(transforms)
        n = images_uint8.shape[0]
        out_h, out_w = self.resize if self.resize is not None else (32, 32)

        out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
        out_images = np.empty((n, 3, out_h, out_w), dtype=out_dtype)

        # Preprocess in chunks to bound peak memory usage.
        for start in range(0, n, self.preprocess_chunk_size):
            end = min(start + self.preprocess_chunk_size, n)
            chunk = images_uint8[start:end].astype(np.float32)
            x = lucid.to_tensor(chunk, dtype=lucid.Float32)
            x = transform(x)
            out_images[start:end] = x.numpy().astype(out_dtype, copy=False)

        if self.cache and self.cache_preprocessed:
            np.savez_compressed(proc_path, images=out_images, labels=labels_int32)

        return out_images, labels_int32

    def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
        """Load the requested split ("train" or anything else = test).

        The split is a deterministic head/tail cut at
        ``floor(N * (1 - test_size))`` — the data is not shuffled here.
        """
        images, labels = self._ensure_raw_cache()
        images, labels = self._maybe_preprocess_and_cache(images, labels)

        train_size = int(math.floor(len(images) * (1 - self.test_size)))
        if split == "train":
            images, labels = images[:train_size], labels[:train_size]
        else:
            images, labels = images[train_size:], labels[train_size:]

        if self.to_tensor:
            images = lucid.to_tensor(images, dtype=lucid.Float32)
            labels = lucid.to_tensor(labels, dtype=lucid.Int32)

        return images, labels

    def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
        """Return the ``(image, label)`` pair at ``index`` with transforms applied."""
        image = self.data[index]
        label = self.targets[index]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
|
|
@@ -32,10 +32,12 @@ class ViT(nn.Module):
|
|
|
32
32
|
in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
|
|
33
33
|
)
|
|
34
34
|
|
|
35
|
-
self.cls_token = nn.Parameter(lucid.
|
|
36
|
-
self.pos_emb = nn.Parameter(
|
|
37
|
-
|
|
38
|
-
)
|
|
35
|
+
self.cls_token = nn.Parameter(lucid.zeros(1, 1, embedding_dim))
|
|
36
|
+
self.pos_emb = nn.Parameter(lucid.zeros(1, 1 + self.num_patches, embedding_dim))
|
|
37
|
+
|
|
38
|
+
nn.init.normal(self.cls_token, std=0.02)
|
|
39
|
+
nn.init.normal(self.pos_emb, std=0.02)
|
|
40
|
+
|
|
39
41
|
self.dropout = nn.Dropout(dropout_rate)
|
|
40
42
|
|
|
41
43
|
encoder_layer = nn.TransformerEncoderLayer(
|
|
@@ -5,6 +5,7 @@ import lucid.nn as nn
|
|
|
5
5
|
import lucid.nn.functional as F
|
|
6
6
|
|
|
7
7
|
from lucid._tensor import Tensor
|
|
8
|
+
from lucid.nn.utils.rnn import PackedSequence
|
|
8
9
|
from lucid.types import Numeric, _DeviceType
|
|
9
10
|
|
|
10
11
|
from .activation import Tanh, ReLU
|
|
@@ -351,21 +352,47 @@ class RNNBase(nn.Module):
|
|
|
351
352
|
)
|
|
352
353
|
|
|
353
354
|
def forward(
|
|
354
|
-
self,
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
355
|
+
self,
|
|
356
|
+
input_: Tensor | PackedSequence,
|
|
357
|
+
hx: Tensor | tuple[Tensor, Tensor] | None = None,
|
|
358
|
+
) -> (
|
|
359
|
+
tuple[Tensor | PackedSequence, Tensor]
|
|
360
|
+
| tuple[Tensor | PackedSequence, tuple[Tensor, Tensor]]
|
|
361
|
+
):
|
|
362
|
+
is_packed = isinstance(input_, PackedSequence)
|
|
363
|
+
if is_packed:
|
|
364
|
+
data = input_.data
|
|
365
|
+
batch_sizes = input_.batch_sizes
|
|
366
|
+
if data.ndim != 2:
|
|
367
|
+
raise ValueError(
|
|
368
|
+
"RNNBase expected packed data with 2 dimensions, "
|
|
369
|
+
f"got {data.ndim} dimensions"
|
|
370
|
+
)
|
|
371
|
+
if batch_sizes.ndim != 1 or batch_sizes.shape[0] == 0:
|
|
372
|
+
raise ValueError(
|
|
373
|
+
"PackedSequence batch_sizes must be a non-empty 1D tensor"
|
|
374
|
+
)
|
|
360
375
|
|
|
361
|
-
|
|
362
|
-
|
|
376
|
+
batch_size = int(batch_sizes[0].item())
|
|
377
|
+
feat = data.shape[1]
|
|
378
|
+
if feat != self.input_size:
|
|
379
|
+
raise ValueError(
|
|
380
|
+
f"RNNBase expected input with feature size {self.input_size}, got {feat}"
|
|
381
|
+
)
|
|
382
|
+
else:
|
|
383
|
+
if input_.ndim != 3:
|
|
384
|
+
raise ValueError(
|
|
385
|
+
f"RNNBase expected input with 3 dimensions, got {input_.ndim} dimensions"
|
|
386
|
+
)
|
|
363
387
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
388
|
+
if self.batch_first:
|
|
389
|
+
input_ = input_.swapaxes(0, 1)
|
|
390
|
+
|
|
391
|
+
seq_len, batch_size, feat = input_.shape
|
|
392
|
+
if feat != self.input_size:
|
|
393
|
+
raise ValueError(
|
|
394
|
+
f"RNNBase expected input with feature size {self.input_size}, got {feat}"
|
|
395
|
+
)
|
|
369
396
|
|
|
370
397
|
if self.is_lstm:
|
|
371
398
|
if hx is None:
|
|
@@ -410,7 +437,7 @@ class RNNBase(nn.Module):
|
|
|
410
437
|
if hx.shape[2] != self.hidden_size:
|
|
411
438
|
raise ValueError("Incorrect hidden size in hx")
|
|
412
439
|
|
|
413
|
-
layer_input = input_
|
|
440
|
+
layer_input = data if is_packed else input_
|
|
414
441
|
h_n_list: list[Tensor] = []
|
|
415
442
|
c_n_list: list[Tensor] | None = [] if self.is_lstm else None
|
|
416
443
|
|
|
@@ -420,33 +447,111 @@ class RNNBase(nn.Module):
|
|
|
420
447
|
c_t = hx_c[layer_idx]
|
|
421
448
|
else:
|
|
422
449
|
h_t = hx[layer_idx]
|
|
450
|
+
|
|
423
451
|
outputs = []
|
|
452
|
+
if is_packed:
|
|
453
|
+
final_h: list[Tensor] = []
|
|
454
|
+
final_c: list[Tensor] | None = [] if self.is_lstm else None
|
|
455
|
+
offset = 0
|
|
456
|
+
|
|
457
|
+
prev_bs: int | None = None
|
|
458
|
+
max_len = int(batch_sizes.shape[0])
|
|
459
|
+
for t in range(max_len):
|
|
460
|
+
bs = int(batch_sizes[t].item())
|
|
461
|
+
if bs == 0:
|
|
462
|
+
break
|
|
463
|
+
|
|
464
|
+
if prev_bs is None:
|
|
465
|
+
prev_bs = bs
|
|
466
|
+
if bs > prev_bs:
|
|
467
|
+
raise ValueError(
|
|
468
|
+
"PackedSequence batch_sizes must be non-increasing"
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
if bs < prev_bs:
|
|
472
|
+
final_h.append(h_t[bs:prev_bs])
|
|
473
|
+
if self.is_lstm and final_c is not None:
|
|
474
|
+
final_c.append(c_t[bs:prev_bs])
|
|
475
|
+
|
|
476
|
+
h_t = h_t[:bs]
|
|
477
|
+
if self.is_lstm:
|
|
478
|
+
c_t = c_t[:bs]
|
|
479
|
+
|
|
480
|
+
step_input = layer_input[offset : offset + bs]
|
|
481
|
+
offset += bs
|
|
482
|
+
|
|
483
|
+
if self.is_lstm:
|
|
484
|
+
h_t, c_t = cell(step_input, (h_t, c_t))
|
|
485
|
+
else:
|
|
486
|
+
h_t = cell(step_input, h_t)
|
|
487
|
+
|
|
488
|
+
outputs.append(h_t)
|
|
489
|
+
prev_bs = bs
|
|
490
|
+
|
|
491
|
+
final_h.append(h_t)
|
|
492
|
+
if self.is_lstm and final_c is not None:
|
|
493
|
+
final_c.append(c_t)
|
|
494
|
+
|
|
495
|
+
h_n_list.append(
|
|
496
|
+
lucid.concatenate(tuple(reversed(final_h)), axis=0).unsqueeze(
|
|
497
|
+
axis=0
|
|
498
|
+
)
|
|
499
|
+
)
|
|
500
|
+
if self.is_lstm and final_c is not None and c_n_list is not None:
|
|
501
|
+
c_n_list.append(
|
|
502
|
+
lucid.concatenate(tuple(reversed(final_c)), axis=0).unsqueeze(
|
|
503
|
+
axis=0
|
|
504
|
+
)
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
layer_output = (
|
|
508
|
+
lucid.concatenate(tuple(outputs), axis=0)
|
|
509
|
+
if outputs
|
|
510
|
+
else layer_input[:0]
|
|
511
|
+
)
|
|
424
512
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
513
|
+
else:
|
|
514
|
+
for t in range(seq_len):
|
|
515
|
+
if self.is_lstm:
|
|
516
|
+
h_t, c_t = cell(layer_input[t], (h_t, c_t))
|
|
517
|
+
outputs.append(h_t.unsqueeze(axis=0))
|
|
518
|
+
else:
|
|
519
|
+
h_t = cell(layer_input[t], h_t)
|
|
520
|
+
outputs.append(h_t.unsqueeze(axis=0))
|
|
432
521
|
|
|
433
|
-
|
|
522
|
+
layer_output = lucid.concatenate(tuple(outputs), axis=0)
|
|
434
523
|
|
|
435
524
|
if self.training and self.dropout > 0.0 and layer_idx < self.num_layers - 1:
|
|
436
525
|
layer_output = F.dropout(layer_output, p=self.dropout)
|
|
437
526
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
c_n_list
|
|
527
|
+
if not is_packed:
|
|
528
|
+
h_n_list.append(h_t.unsqueeze(axis=0))
|
|
529
|
+
if self.is_lstm and c_n_list is not None:
|
|
530
|
+
c_n_list.append(c_t.unsqueeze(axis=0))
|
|
441
531
|
layer_input = layer_output
|
|
442
532
|
|
|
443
|
-
|
|
533
|
+
if is_packed:
|
|
534
|
+
output = PackedSequence(
|
|
535
|
+
data=layer_input,
|
|
536
|
+
batch_sizes=batch_sizes,
|
|
537
|
+
sorted_indices=input_.sorted_indices,
|
|
538
|
+
unsorted_indices=input_.unsorted_indices,
|
|
539
|
+
)
|
|
540
|
+
else:
|
|
541
|
+
output = layer_input
|
|
542
|
+
|
|
444
543
|
h_n = lucid.concatenate(tuple(h_n_list), axis=0)
|
|
445
544
|
if self.is_lstm and c_n_list is not None:
|
|
446
545
|
c_n = lucid.concatenate(tuple(c_n_list), axis=0)
|
|
447
546
|
|
|
448
|
-
if
|
|
449
|
-
|
|
547
|
+
if is_packed:
|
|
548
|
+
if input_.unsorted_indices is not None:
|
|
549
|
+
h_n = h_n[:, input_.unsorted_indices]
|
|
550
|
+
if self.is_lstm and c_n_list is not None:
|
|
551
|
+
c_n = c_n[:, input_.unsorted_indices]
|
|
552
|
+
else:
|
|
553
|
+
if self.batch_first:
|
|
554
|
+
output = output.swapaxes(0, 1)
|
|
450
555
|
|
|
451
556
|
if self.is_lstm and c_n_list is not None:
|
|
452
557
|
return output, (h_n, c_n)
|
|
@@ -6,7 +6,7 @@ from lucid._tensor import Tensor
|
|
|
6
6
|
from lucid.types import _Scalar
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
__all__ = ["grad_norm", "clip_grad_norm", "clip_grad_value"]
|
|
9
|
+
__all__ = ["grad_norm", "get_total_norm", "clip_grad_norm", "clip_grad_value"]
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def _as_iter(parameters: Iterable[Tensor] | Tensor) -> list[Tensor]:
|
|
@@ -32,6 +32,25 @@ def grad_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tens
|
|
|
32
32
|
return Tensor(total_norm, device=device)
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
def get_total_norm(parameters: Iterable[Tensor] | Tensor, norm_type: int = 2) -> Tensor:
    """Compute the total ``norm_type``-norm over the gradients of *parameters*.

    The total norm is ``(sum_i ||g_i||_p ** p) ** (1/p)`` over all non-None
    gradients ``g_i``, each flattened first. Returns a zero scalar Tensor when
    there are no parameters or no gradients.

    Args:
        parameters: A Tensor or an iterable of Tensors whose ``.grad`` fields
            are accumulated.
        norm_type: The order ``p`` of the norm (default 2).

    Returns:
        A scalar Tensor on the first parameter's device holding the total norm.
    """
    params = _as_iter(parameters)
    if not params:
        return Tensor(0.0)

    device = params[0].device
    grads: list[Tensor] = [p.grad for p in params if p.grad is not None]
    if not grads:
        return Tensor(0.0, device=device)

    # Accumulate ||g||_p ** p per gradient, then take the p-th root.
    # The local is named `g_norm` (not `grad_norm`) to avoid shadowing the
    # module-level `grad_norm` function exported in __all__.
    norm_pow_sum = 0.0
    for g in grads:
        g_norm = lucid.linalg.norm(lucid.ravel(g), ord=norm_type).item()
        norm_pow_sum += g_norm**norm_type

    total_norm = norm_pow_sum ** (1.0 / norm_type)
    return Tensor(total_norm, device=device)
|
|
52
|
+
|
|
53
|
+
|
|
35
54
|
def clip_grad_norm(
|
|
36
55
|
parameters: Iterable[Tensor] | Tensor,
|
|
37
56
|
max_norm: _Scalar,
|
|
@@ -39,7 +58,7 @@ def clip_grad_norm(
|
|
|
39
58
|
eps: float = 1e-7,
|
|
40
59
|
) -> float:
|
|
41
60
|
params: list[Tensor] = [p for p in _as_iter(parameters) if p.grad is not None]
|
|
42
|
-
total_norm =
|
|
61
|
+
total_norm = get_total_norm(params, norm_type=norm_type)
|
|
43
62
|
|
|
44
63
|
clip_coef = float(max_norm) / (total_norm.item() + eps)
|
|
45
64
|
if clip_coef < 1.0:
|