lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/__init__.py +4 -2
- lucid/_backend/core.py +89 -9
- lucid/_backend/metal.py +5 -1
- lucid/_func/__init__.py +162 -0
- lucid/_tensor/{tensor_ops.py → base.py} +64 -0
- lucid/_tensor/tensor.py +63 -19
- lucid/autograd/__init__.py +4 -1
- lucid/datasets/mnist.py +135 -6
- lucid/models/imggen/__init__.py +1 -0
- lucid/models/imggen/ncsn.py +402 -0
- lucid/nn/_kernel/__init__.py +1 -0
- lucid/nn/_kernel/activation.py +188 -0
- lucid/nn/_kernel/attention.py +125 -0
- lucid/{_backend → nn/_kernel}/conv.py +4 -13
- lucid/nn/_kernel/embedding.py +72 -0
- lucid/nn/_kernel/loss.py +416 -0
- lucid/nn/_kernel/norm.py +365 -0
- lucid/{_backend → nn/_kernel}/pool.py +7 -27
- lucid/nn/functional/__init__.py +4 -0
- lucid/nn/functional/_activation.py +19 -13
- lucid/nn/functional/_attention.py +9 -0
- lucid/nn/functional/_conv.py +5 -16
- lucid/nn/functional/_loss.py +31 -32
- lucid/nn/functional/_norm.py +60 -69
- lucid/nn/functional/_pool.py +7 -7
- lucid/nn/functional/_util.py +5 -1
- lucid/nn/init/_dist.py +1 -0
- lucid/types.py +24 -2
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/METADATA +7 -5
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/RECORD +33 -26
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/WHEEL +1 -1
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.11.0.dist-info → lucid_dl-2.11.2.dist-info}/top_level.txt +0 -0
lucid/datasets/mnist.py
CHANGED
|
@@ -4,6 +4,8 @@ import openml
|
|
|
4
4
|
import math
|
|
5
5
|
|
|
6
6
|
from typing import SupportsIndex, Tuple, ClassVar
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import re
|
|
7
9
|
|
|
8
10
|
import lucid
|
|
9
11
|
from lucid._tensor import Tensor
|
|
@@ -17,6 +19,42 @@ __all__ = ["MNIST", "FashionMNIST"]
|
|
|
17
19
|
class MNIST(DatasetBase):
|
|
18
20
|
OPENML_ID: ClassVar[int] = 554
|
|
19
21
|
|
|
22
|
+
def __init__(
    self,
    root: str | Path,
    train: bool | None = True,
    download: bool | None = False,
    transform: lucid.nn.Module | None = None,
    target_transform: lucid.nn.Module | None = None,
    test_size: float = 0.2,
    to_tensor: bool = True,
    *,
    cache: bool = True,
    scale: float | None = None,
    resize: tuple[int, int] | None = None,
    normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
    cache_preprocessed: bool = True,
    preprocess_dtype: lucid.Numeric = lucid.Float16,
    preprocess_chunk_size: int = 4096,
) -> None:
    """Record caching/preprocessing options, then hand off to ``DatasetBase``.

    Args:
        root: Directory containing ``MNIST.csv`` and the ``.npz`` cache files.
        train, download, transform, target_transform, test_size, to_tensor:
            Passed through unchanged to ``DatasetBase.__init__``.
        cache: Read/write the raw int16 ``.npz`` cache under ``root``.
        scale: Optional multiplicative factor applied to pixel values.
        resize: Optional ``(height, width)`` output size.
        normalize: Optional ``(mean, std)`` sequences for ``Normalize``.
        cache_preprocessed: Also cache the preprocessed images; only takes
            effect when ``cache`` is true.
        preprocess_dtype: Storage dtype for preprocessed images
            (``lucid.Float16`` selects float16, anything else float32).
        preprocess_chunk_size: Number of images transformed per batch.
    """
    # Everything the cache/preprocess helpers read is stored *before* the base
    # constructor runs. NOTE(review): this assumes DatasetBase.__init__ may
    # load data (and thus call those helpers) — confirm against DatasetBase.
    self.cache, self.cache_preprocessed = cache, cache_preprocessed
    self.scale, self.resize, self.normalize = scale, resize, normalize
    self.preprocess_dtype = preprocess_dtype
    self.preprocess_chunk_size = preprocess_chunk_size

    super().__init__(
        root=root,
        train=train,
        download=download,
        test_size=test_size,
        to_tensor=to_tensor,
        transform=transform,
        target_transform=target_transform,
    )
|
|
57
|
+
|
|
20
58
|
def _download(self) -> None:
|
|
21
59
|
try:
|
|
22
60
|
dataset = openml.datasets.get_dataset(self.OPENML_ID)
|
|
@@ -26,18 +64,106 @@ class MNIST(DatasetBase):
|
|
|
26
64
|
except Exception as e:
|
|
27
65
|
raise RuntimeError(f"Failed to download the MNIST dataset. Error: {e}")
|
|
28
66
|
|
|
29
|
-
def
|
|
67
|
+
def _cache_key(self) -> str:
    """Return a filesystem-safe tag describing the active preprocessing.

    Yields ``"raw"`` when none of ``scale``, ``resize`` or ``normalize`` is
    set. Otherwise the tag joins with ``_`` the pieces ``s<scale>``,
    ``r<H>x<W>``, ``m<means>`` and ``v<stds>``. Numbers are formatted with
    ``:g``. Any run of characters outside ``[a-zA-Z0-9_,.x-]`` is collapsed
    to ``_``.
    """
    segments: list[str] = []

    if self.scale is not None:
        segments.append(f"s{self.scale:g}")

    if self.resize is not None:
        height, width = self.resize[0], self.resize[1]
        segments.append(f"r{height}x{width}")

    if self.normalize is not None:
        means, stds = self.normalize
        segments.extend(
            (
                "m" + ",".join(f"{m:g}" for m in means),
                "v" + ",".join(f"{s:g}" for s in stds),
            )
        )

    if not segments:
        return "raw"
    return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", "_".join(segments))
|
|
81
|
+
|
|
82
|
+
def _raw_cache_path(self) -> Path:
    """Path of the int16 raw-pixel cache file inside ``self.root``."""
    cache_name = "MNIST_int16.npz"
    return self.root / cache_name
|
|
84
|
+
|
|
85
|
+
def _proc_cache_path(self) -> Path:
    """Path of the preprocessed-image cache for the current settings.

    The filename embeds ``_cache_key()`` and ``str(self.preprocess_dtype)``,
    so datasets configured differently do not share a cache file.
    """
    dtype_tag = str(self.preprocess_dtype)
    settings_tag = self._cache_key()
    return self.root / f"MNIST_{settings_tag}_{dtype_tag}.npz"
|
|
88
|
+
|
|
89
|
+
def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(images, labels)`` for the full MNIST CSV, using an ``.npz`` cache.

    When ``self.cache`` is true and ``_raw_cache_path()`` exists, the arrays
    stored there are returned as-is. Otherwise ``root/MNIST.csv`` is parsed:
    the ``class`` column becomes int32 labels, and the remaining columns
    become int16 images reshaped to ``(N, 1, 28, 28)``. With caching enabled,
    the result is then saved for later runs.

    Raises:
        RuntimeError: if no usable cache exists and ``root/MNIST.csv`` is
            missing.
    """
    raw_path = self._raw_cache_path()
    if self.cache and raw_path.exists():
        with np.load(raw_path) as npz:
            return npz["images"], npz["labels"]

    csv_path = self.root / "MNIST.csv"
    if not csv_path.exists():
        raise RuntimeError(
            f"MNIST dataset CSV file not found at {csv_path}. Use `download=True`."
        )

    df = pd.read_csv(csv_path)
    labels = df["class"].values.astype(np.int32)
    # Cast straight to int16 (MNIST pixels are presumably integers in 0..255)
    # instead of first materialising a full float32 copy of the table.
    images = df.drop(columns=["class"]).values.astype(np.int16).reshape(-1, 1, 28, 28)

    if self.cache:
        # Save to a temporary file, then rename it over the target. An
        # interrupted save therefore never leaves a truncated .npz that
        # later runs would try (and fail) to load. The temp name must end
        # in ".npz", or numpy would append the suffix itself.
        tmp_path = raw_path.with_suffix(".tmp.npz")
        np.savez_compressed(tmp_path, images=images, labels=labels)
        tmp_path.replace(raw_path)

    return images, labels
|
|
112
|
+
|
|
113
|
+
def _maybe_preprocess_and_cache(
    self, images_int16: np.ndarray, labels_int32: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Apply the configured resize/scale/normalize pipeline, with caching.

    If none of ``resize``, ``scale`` or ``normalize`` is set, the images are
    just cast to float32. Otherwise:
    - If ``cache`` and ``cache_preprocessed`` are both true and the file at
      ``_proc_cache_path()`` exists, the stored arrays are returned.
    - Otherwise the images are transformed in batches of
      ``preprocess_chunk_size``, in the order resize, then scale, then
      normalize. The result is saved to that path when caching is enabled.

    Output images are float16 when ``preprocess_dtype == lucid.Float16`` and
    float32 otherwise. Labels pass through unchanged.

    Raises:
        ValueError: if preprocessing must run and ``preprocess_chunk_size``
            is smaller than 1.
    """
    if self.resize is None and self.scale is None and self.normalize is None:
        return images_int16.astype(np.float32), labels_int32

    proc_path = self._proc_cache_path()
    use_proc_cache = self.cache and self.cache_preprocessed
    if use_proc_cache and proc_path.exists():
        with np.load(proc_path) as npz:
            return npz["images"], npz["labels"]

    chunk_size = self.preprocess_chunk_size
    if chunk_size < 1:
        # range() raises on 0 and silently yields nothing for negative steps,
        # which would return the uninitialised np.empty buffer below.
        raise ValueError(
            f"preprocess_chunk_size must be a positive integer, got {chunk_size}."
        )

    from lucid.transforms import Compose, Resize, Normalize

    class _Scale(lucid.nn.Module):
        """Multiply the input by a constant factor."""

        def __init__(self, factor: float) -> None:
            super().__init__()
            self.factor = factor

        def forward(self, x: Tensor) -> Tensor:
            return x * self.factor

    transforms: list[lucid.nn.Module] = []
    if self.resize is not None:
        transforms.append(Resize(self.resize))
    if self.scale is not None:
        transforms.append(_Scale(self.scale))
    if self.normalize is not None:
        mean, std = self.normalize
        transforms.append(Normalize(mean=mean, std=std))

    transform = Compose(transforms)
    n = images_int16.shape[0]
    out_h, out_w = self.resize if self.resize is not None else (28, 28)

    out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
    out_images = np.empty((n, 1, out_h, out_w), dtype=out_dtype)

    # Transform in bounded batches to cap peak memory of the float32 tensors.
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        chunk = images_int16[start:end].astype(np.float32)
        x = transform(lucid.to_tensor(chunk, dtype=lucid.Float32))
        out_images[start:end] = x.numpy().astype(out_dtype, copy=False)

    if use_proc_cache:
        # Save to a temp file, then rename, so an interrupted save cannot
        # leave a truncated cache behind. The ".npz" suffix is kept so numpy
        # does not rename the file.
        tmp_path = proc_path.with_suffix(".tmp.npz")
        np.savez_compressed(tmp_path, images=out_images, labels=labels_int32)
        tmp_path.replace(proc_path)

    return out_images, labels_int32
|
|
163
|
+
|
|
164
|
+
def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
|
|
165
|
+
images, labels = self._ensure_raw_cache()
|
|
166
|
+
images, labels = self._maybe_preprocess_and_cache(images, labels)
|
|
41
167
|
|
|
42
168
|
train_size = int(math.floor(len(images) * (1 - self.test_size)))
|
|
43
169
|
if split == "train":
|
|
@@ -46,13 +172,16 @@ class MNIST(DatasetBase):
|
|
|
46
172
|
images, labels = images[train_size:], labels[train_size:]
|
|
47
173
|
|
|
48
174
|
if self.to_tensor:
|
|
49
|
-
images
|
|
175
|
+
if images.dtype == np.float16 and self.preprocess_dtype == lucid.Float16:
|
|
176
|
+
images = lucid.to_tensor(images, dtype=lucid.Float16)
|
|
177
|
+
else:
|
|
178
|
+
images = lucid.to_tensor(images, dtype=lucid.Float32)
|
|
50
179
|
labels = lucid.to_tensor(labels, dtype=lucid.Int32)
|
|
51
180
|
|
|
52
181
|
return images, labels
|
|
53
182
|
|
|
54
183
|
def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
|
|
55
|
-
image = self.data[index]
|
|
184
|
+
image = self.data[index]
|
|
56
185
|
label = self.targets[index]
|
|
57
186
|
|
|
58
187
|
if self.transform:
|
lucid/models/imggen/ncsn.py
CHANGED
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Literal, Sequence
|
|
3
|
+
|
|
4
|
+
import lucid
|
|
5
|
+
import lucid.nn as nn
|
|
6
|
+
import lucid.nn.functional as F
|
|
7
|
+
|
|
8
|
+
from lucid._tensor import Tensor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Public API of this module: only the NCSN model. The underscore-prefixed
# building blocks below are internal helpers.
__all__ = ["NCSN"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _CondInstanceNorm(nn.Module):
    """Instance normalization with a per-label affine transform.

    ``x`` is normalized by a non-affine ``nn.InstanceNorm2d``. An embedding of
    width ``2 * num_features`` maps each label to ``gamma`` (first half) and
    ``beta`` (second half). The output is ``h * gamma + beta``.

    Args:
        num_features: Channel count of the input.
        num_classes: Number of distinct labels (rows of the embedding).
        eps: Epsilon forwarded to the instance norm.
    """

    def __init__(self, num_features: int, num_classes: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.num_features = num_features
        self.norm = nn.InstanceNorm2d(num_features, affine=False, eps=eps)
        self.embed = nn.Embedding(num_classes, num_features * 2)

        # Start every label at the identity transform (gamma = 1, beta = 0).
        # Write into the embedding's backing array directly. Slicing
        # ``self.embed.weight`` yields a separate Tensor, and initialising that
        # slice is not guaranteed to propagate back to the embedding.
        # NOTE(review): assumes ``.data`` is a mutable array (NumPy on CPU).
        weight = self.embed.weight.data
        weight[:, :num_features] = 1.0
        weight[:, num_features:] = 0.0

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Normalize ``x`` and modulate it with the embedding of labels ``y``.

        The ``(-1, C, 1, 1)`` reshape of gamma/beta implies ``x`` is 4D
        ``(N, C, H, W)``. ``y`` is cast to ``lucid.Long`` if needed.
        """
        if y.dtype != lucid.Long:
            y = y.long()

        h = self.norm(x)
        gamma, beta = lucid.chunk(self.embed(y), 2, axis=1)

        gamma = gamma.reshape(-1, self.num_features, 1, 1)
        beta = beta.reshape(-1, self.num_features, 1, 1)

        return h * gamma + beta
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class _Conv3x3(nn.Module):
    """3x3, stride-1 convolution whose padding equals its dilation.

    With ``padding == dilation``, a 3x3 kernel at stride 1 keeps the spatial
    size unchanged: ``H_out = H + 2d - 2d = H``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, dilation: int = 1, bias: bool = True
    ) -> None:
        super().__init__()
        conv_kwargs = dict(
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        return out
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class _ResidualConvUnit(nn.Module):
    """Pre-activation residual unit conditioned on labels ``y``.

    Returns ``x + conv2(ELU(norm2(conv1(ELU(norm1(x, y))), y)))``. Both
    convolutions are size-preserving ``_Conv3x3`` layers with the given
    dilation, so the residual sum is shape-compatible.
    """

    def __init__(self, channels: int, num_classes: int, dilation: int = 1) -> None:
        super().__init__()
        self.norm1 = _CondInstanceNorm(channels, num_classes)
        self.conv1 = _Conv3x3(channels, channels, dilation=dilation)

        self.norm2 = _CondInstanceNorm(channels, num_classes)
        self.conv2 = _Conv3x3(channels, channels, dilation=dilation)

        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        branch = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            branch = conv(self.act(norm(branch, y)))
        return x + branch
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class _RCUBlock(nn.Module):
    """Stack of ``num_units`` ``_ResidualConvUnit`` layers sharing one dilation.

    ``forward(x, y)`` runs ``x`` through the units in order and passes the
    same labels ``y`` to each one.
    """

    def __init__(
        self, channels: int, num_classes: int, num_units: int = 2, dilation: int = 1
    ) -> None:
        super().__init__()
        self.units = nn.ModuleList(
            [
                _ResidualConvUnit(channels, num_classes, dilation=dilation)
                for _unit_index in range(num_units)
            ]
        )

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        for residual_unit in self.units:
            x = residual_unit(x, y)
        return x
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class _CondAdapter(nn.Module):
    """Map ``in_channels`` to ``out_channels`` via norm -> ELU -> 1x1 conv.

    If the channel counts already match, the adapter is the identity.
    ``self.norm`` and ``self.conv`` are then ``None``, so no parameters are
    created for them.
    """

    def __init__(self, in_channels: int, out_channels: int, num_classes: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        needs_projection = in_channels != out_channels
        if needs_projection:
            self.norm = _CondInstanceNorm(in_channels, num_classes)
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.norm = None
            self.conv = None
        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if self.in_channels != self.out_channels:
            return self.conv(self.act(self.norm(x, y)))
        return x
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class _MultiResFusion(nn.Module):
    """Fuse several feature maps into one by summation.

    Each input ``x_k`` becomes ``conv3x3_k(ELU(norm_k(x_k, y)))`` with
    ``out_channels`` channels. It is nearest-neighbour resized to the largest
    height and width among the inputs, and the branches are added in input
    order.

    Raises (in ``forward``):
        ValueError: if the number of inputs differs from the number of
            branches built at construction.
    """

    def __init__(
        self, in_channels_arr: Sequence[int], out_channels: int, num_classes: int
    ) -> None:
        super().__init__()
        self.out_channels = out_channels
        self.norms = nn.ModuleList(
            [_CondInstanceNorm(c, num_classes) for c in in_channels_arr]
        )
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(c, out_channels, kernel_size=3, stride=1, padding=1)
                for c in in_channels_arr
            ]
        )
        self.act = nn.ELU()

    def forward(self, xs: Sequence[Tensor], y: Tensor) -> Tensor:
        if len(xs) != len(self.convs):
            raise ValueError(f"Expected {len(self.convs)} inputs, got {len(xs)}")

        target_size = (
            max(feat.shape[-2] for feat in xs),
            max(feat.shape[-1] for feat in xs),
        )

        total = None
        for feat, norm, conv in zip(xs, self.norms, self.convs):
            branch = conv(self.act(norm(feat, y)))
            if branch.shape[-2:] != target_size:
                branch = F.interpolate(branch, size=target_size, mode="nearest")
            total = branch if total is None else total + branch

        return total
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class _ChainedResPooling(nn.Module):
    """Chained residual pooling with ``num_stages`` label-conditioned stages.

    Each stage computes ``path = conv(maxpool(ELU(norm(path, y))))`` from the
    previous stage's ``path``. The result is added to a running sum that
    starts at ``x``. Max-pooling (5x5, stride 1, padding 2) and the 3x3 /
    padding-1 convolutions both preserve spatial size.
    """

    def __init__(self, channels: int, num_classes: int, num_stages: int = 4) -> None:
        super().__init__()
        self.norms = nn.ModuleList(
            [_CondInstanceNorm(channels, num_classes) for _ in range(num_stages)]
        )
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
                for _ in range(num_stages)
            ]
        )
        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        running_sum = x
        path = x
        for stage_norm, stage_conv in zip(self.norms, self.convs):
            pooled = F.max_pool2d(
                self.act(stage_norm(path, y)), kernel_size=5, stride=1, padding=2
            )
            path = stage_conv(pooled)
            running_sum = running_sum + path
        return running_sum
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class _RefineBlock(nn.Module):
    """RefineNet-style block: adapt, refine, fuse, pool, refine again.

    Each input is first channel-adapted to ``out_channels``. It then passes
    through its own two-unit ``_RCUBlock``. The branches are then combined:
    - a single branch is used as-is;
    - several branches are fused by ``_MultiResFusion``.
    The combined map goes through ``_ChainedResPooling`` and a final
    two-unit ``_RCUBlock``.

    Raises (in ``forward``):
        ValueError: if the number of inputs differs from ``len(in_channels_arr)``.
    """

    def __init__(
        self, in_channels_arr: Sequence[int], out_channels: int, num_classes: int
    ) -> None:
        super().__init__()
        self.adapters = nn.ModuleList(
            [_CondAdapter(c, out_channels, num_classes) for c in in_channels_arr]
        )
        self.rcu_in = nn.ModuleList(
            [_RCUBlock(out_channels, num_classes, num_units=2) for _ in in_channels_arr]
        )
        self.msf = _MultiResFusion(
            [out_channels] * len(in_channels_arr), out_channels, num_classes
        )
        self.crp = _ChainedResPooling(out_channels, num_classes, num_stages=4)
        self.rcu_out = _RCUBlock(out_channels, num_classes, num_units=2)

    def forward(self, xs: Sequence[Tensor], y: Tensor) -> Tensor:
        if len(xs) != len(self.adapters):
            raise ValueError(f"Expected {len(self.adapters)} inputs, got {len(xs)}")

        branches = [
            rcu(adapter(feat, y), y)
            for feat, adapter, rcu in zip(xs, self.adapters, self.rcu_in)
        ]

        fused = branches[0] if len(branches) == 1 else self.msf(branches, y)
        return self.rcu_out(self.crp(fused, y), y)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@nn.set_state_dict_pass_attr("sigmas")
class NCSN(nn.Module):
    """Noise Conditional Score Network with a RefineNet-style backbone.

    All stages work at the input resolution (size-preserving dilated 3x3
    convolutions; no downsampling). Every normalization layer is conditioned
    on an integer noise-level label.

    The noise schedule lives in the ``sigmas`` buffer (length
    ``num_classes``). It is allocated with ``lucid.empty``, so it must be
    filled with ``set_sigmas`` before training or sampling, e.g.
    ``model.set_sigmas(NCSN.make_sigmas(1.0, 0.01, model.num_classes))``.

    Args:
        in_channels: Channels of the input images and of the predicted score.
        nf: Feature width used throughout the network.
        num_classes: Number of noise levels (length of ``sigmas``).
        dilations: Exactly four dilation rates, one per ``_RCUBlock`` stage.
        scale_by_sigma: If true, divide the network output by ``sigmas[labels]``.

    Raises:
        ValueError: if ``dilations`` does not contain exactly four values.
    """

    def __init__(
        self,
        in_channels: int = 3,
        nf: int = 128,
        num_classes: int = 10,
        dilations: Sequence[int] = (1, 2, 4, 8),
        scale_by_sigma: bool = True,
    ) -> None:
        super().__init__()
        if len(dilations) != 4:
            raise ValueError("Expected 4 dilation values (for 4 RefineNet stages).")

        self.in_channels = in_channels
        self.nf = nf
        self.num_classes = num_classes
        self.scale_by_sigma = bool(scale_by_sigma)

        # One sigma per label. Allocated with lucid.empty, so treat the
        # contents as undefined until ``set_sigmas`` is called.
        self.sigmas: nn.Buffer
        self.register_buffer("sigmas", lucid.empty(num_classes))

        self.begin_conv = nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)

        # Encoder: four dilated residual stages at full resolution.
        self.stage1 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[0])
        self.stage2 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[1])
        self.stage3 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[2])
        self.stage4 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[3])

        # Decoder: refine4 sees only the deepest features; each later block
        # fuses one encoder stage with the previous refinement.
        self.refine4 = _RefineBlock([nf], nf, num_classes)
        self.refine3 = _RefineBlock([nf, nf], nf, num_classes)
        self.refine2 = _RefineBlock([nf, nf], nf, num_classes)
        self.refine1 = _RefineBlock([nf, nf], nf, num_classes)

        self.end_norm = _CondInstanceNorm(nf, num_classes)
        self.end_act = nn.ELU()
        self.end_conv = nn.Conv2d(nf, in_channels, kernel_size=3, stride=1, padding=1)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Xavier-uniform conv weights and zero conv biases; other modules untouched."""
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform(m.weight)
            if m.bias is not None:
                nn.init.constant(m.bias, 0.0)

    def forward(self, x: Tensor, labels: Tensor) -> Tensor:
        """Estimate the score of ``x`` at the noise levels given by ``labels``.

        ``labels`` is flattened to 1D and cast to ``lucid.Long``. The Long
        cast is needed because the labels also index ``sigmas`` directly.
        When ``scale_by_sigma`` is set, the output is divided by
        ``sigmas[labels]`` (broadcast as ``(N, 1, 1, 1)``).

        Raises:
            RuntimeError: if ``scale_by_sigma`` is set and the ``sigmas``
                buffer does not have ``num_classes`` entries.
        """
        if labels.ndim != 1:
            labels = labels.reshape(-1)
        if labels.dtype != lucid.Long:
            labels = labels.long()

        h = self.begin_conv(x)
        h1 = self.stage1(h, labels)
        h2 = self.stage2(h1, labels)
        h3 = self.stage3(h2, labels)
        h4 = self.stage4(h3, labels)

        r4 = self.refine4([h4], labels)
        r3 = self.refine3([h3, r4], labels)
        r2 = self.refine2([h2, r3], labels)
        r1 = self.refine1([h1, r2], labels)

        out = self.end_conv(self.end_act(self.end_norm(r1, labels)))
        if self.scale_by_sigma:
            if self.sigmas.size != self.num_classes:
                raise RuntimeError(
                    f"'sigmas' buffer has shape {self.sigmas.shape}; "
                    f"expected ({self.num_classes},). Call 'set_sigmas(...)'."
                )

            used_sigmas = self.sigmas[labels].reshape(-1, 1, 1, 1)
            out = out / used_sigmas

        return out

    @lucid.no_grad()
    def set_sigmas(self, sigmas: Tensor) -> None:
        """Copy a 1D noise schedule of length ``num_classes`` into ``self.sigmas``.

        The values are detached and converted to the buffer's device and
        dtype before being stored.

        Raises:
            ValueError: if ``sigmas`` is not 1D or has the wrong length.
        """
        if sigmas.ndim != 1:
            raise ValueError("sigmas must be 1D.")
        if sigmas.size != self.num_classes:
            raise ValueError(
                f"sigmas length ({sigmas.size}) must match "
                f"num_classes ({self.num_classes})."
            )
        # Keep the tensor returned by each ``to`` call (as ``sample`` does with
        # ``init.to(...)``). This stays correct whether ``to`` converts in
        # place or returns a new tensor.
        tmp = sigmas.detach()
        tmp = tmp.to(self.sigmas.device)
        tmp = tmp.to(self.sigmas.dtype)
        self.sigmas.data = tmp.data

    @staticmethod
    @lucid.no_grad()
    def make_sigmas(sigma_begin: float, sigma_end: float, num_scales: int) -> Tensor:
        """Return a descending geometric schedule from ``sigma_begin`` to ``sigma_end``.

        The values are ``exp(linspace(log(sigma_begin), log(sigma_end), num_scales))``.

        Raises:
            ValueError: if either endpoint is non-positive, if
                ``sigma_begin <= sigma_end``, or if ``num_scales < 2``.
        """
        if sigma_begin <= 0 or sigma_end <= 0:
            raise ValueError("sigmas must be positive.")
        if sigma_begin <= sigma_end:
            raise ValueError(
                "Expected sigma_begin > sigma_end (descending noise schedule)."
            )
        if num_scales < 2:
            raise ValueError("num_scales must be >= 2.")

        return lucid.exp(
            lucid.linspace(math.log(sigma_begin), math.log(sigma_end), num_scales)
        )

    def get_loss(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Denoising score-matching loss at randomly drawn noise levels.

        Steps:
        1. Draw one label per example with
           ``lucid.random.randint(0, len(sigmas))``.
        2. Perturb ``x`` with Gaussian noise scaled by ``sigmas[label]``.
        3. Compute ``sum((score * sigma + noise) ** 2)`` over axes 1-3 and
           average it over the batch.

        The axes and the ``(batch, 1, 1, 1)`` reshape mean ``x`` is expected
        to be 4D.

        Returns:
            ``(loss, labels)``: the scalar loss and the sampled labels.
        """
        batch_size = x.shape[0]
        labels = lucid.random.randint(
            0, self.sigmas.shape[0], (batch_size,), device=x.device
        ).long()
        used_sigmas = self.sigmas[labels].reshape(batch_size, 1, 1, 1)

        noise = lucid.random.randn(x.shape, device=x.device)
        perturbed = x + used_sigmas * noise
        score = self.forward(perturbed, labels)

        loss = lucid.sum((score * used_sigmas + noise) ** 2, axis=(1, 2, 3)).mean()
        return loss, labels

    @lucid.no_grad()
    def sample(
        self,
        n_samples: int,
        image_size: int,
        in_channels: int,
        n_steps_each: int,
        step_lr: float,
        clip: bool = True,
        denoise: bool = False,
        init: Tensor | None = None,
        init_dist: Literal["uniform", "normal"] = "uniform",
        verbose: bool = True,
    ) -> Tensor:
        """Generate samples with annealed Langevin dynamics.

        For each noise level ``i`` (in ``sigmas`` order), run ``n_steps_each``
        updates of ``x <- x + a_i * s(x, i) + sqrt(2 * a_i) * z``, where
        ``a_i = step_lr * (sigma_i / sigma_last) ** 2`` and ``z ~ N(0, I)``.

        Args:
            n_samples, image_size, in_channels: Output shape
                ``(n_samples, in_channels, image_size, image_size)``.
            n_steps_each: Langevin steps per noise level.
            step_lr: Base step size.
            clip: Clamp ``x`` to ``[-1, 1]`` after every update.
            denoise: Finish with one step ``x + sigma_last ** 2 * s(x, last)``.
            init: Optional starting tensor; must have the output shape.
            init_dist: Starting distribution when ``init`` is None:
                ``"uniform"`` on ``[-1, 1]`` or standard ``"normal"``.
            verbose: Show a tqdm progress bar. tqdm is only imported in that
                case, so it is not required for silent sampling.

        Note:
            Calls ``self.eval()`` and does not restore the previous mode.

        Raises:
            ValueError: on an unknown ``init_dist`` or a mis-shaped ``init``.
        """
        self.eval()
        expected_shape = (n_samples, in_channels, image_size, image_size)

        if init is None:
            if init_dist == "uniform":
                x = lucid.random.uniform(
                    -1.0,
                    1.0,
                    expected_shape,
                    device=self.device,
                )
            elif init_dist == "normal":
                x = lucid.random.randn(
                    n_samples, in_channels, image_size, image_size, device=self.device
                )
            else:
                raise ValueError("init_dist must be either 'uniform' or 'normal'.")
        else:
            x = init.to(self.device)
            if x.shape != expected_shape:
                raise ValueError(
                    f"init has shape {x.shape} but expected "
                    f"{expected_shape}."
                )

        pbar = None
        if verbose:
            # Lazy import: tqdm is only needed when a progress bar is shown.
            from tqdm import tqdm

            total = int(self.sigmas.shape[0]) * int(n_steps_each)
            pbar = tqdm(total=total, desc="Sampling", dynamic_ncols=True)

        try:
            for i, sigma in enumerate(self.sigmas):
                labels = lucid.full(n_samples, i, device=self.device, dtype=lucid.Long)
                # Step size grows with the noise level relative to the last sigma.
                step_size = step_lr * (sigma / self.sigmas[-1]) ** 2

                for j in range(n_steps_each):
                    grad = self.forward(x, labels)
                    noise = lucid.random.randn(x.shape, device=self.device)

                    x = x + step_size * grad + lucid.sqrt(2.0 * step_size) * noise
                    if clip:
                        x = x.clip(-1.0, 1.0)
                    if pbar is not None:
                        pbar.update(1)
                        pbar.set_postfix(sigma=f"{sigma.item():.4f}", l=i, t=j)

            if denoise:
                last_label = lucid.full(
                    n_samples,
                    self.sigmas.shape[0] - 1,
                    device=self.device,
                    dtype=lucid.Long,
                )
                x = x + (self.sigmas[-1] ** 2) * self.forward(x, last_label)
                if clip:
                    x = x.clip(-1.0, 1.0)
        finally:
            # Close the bar even if sampling raises, so the terminal is not
            # left in a broken state.
            if pbar is not None:
                pbar.close()

        return x
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# This module body is empty; ``pass`` is its only statement.
pass
|