lucid-dl 2.11.0__py3-none-any.whl → 2.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lucid/datasets/mnist.py CHANGED
@@ -4,6 +4,8 @@ import openml
4
4
  import math
5
5
 
6
6
  from typing import SupportsIndex, Tuple, ClassVar
7
+ from pathlib import Path
8
+ import re
7
9
 
8
10
  import lucid
9
11
  from lucid._tensor import Tensor
@@ -17,6 +19,42 @@ __all__ = ["MNIST", "FashionMNIST"]
17
19
  class MNIST(DatasetBase):
18
20
  OPENML_ID: ClassVar[int] = 554
19
21
 
22
    def __init__(
        self,
        root: str | Path,
        train: bool | None = True,
        download: bool | None = False,
        transform: lucid.nn.Module | None = None,
        target_transform: lucid.nn.Module | None = None,
        test_size: float = 0.2,
        to_tensor: bool = True,
        *,
        cache: bool = True,
        scale: float | None = None,
        resize: tuple[int, int] | None = None,
        normalize: tuple[tuple[float, ...], tuple[float, ...]] | None = None,
        cache_preprocessed: bool = True,
        preprocess_dtype: lucid.Numeric = lucid.Float16,
        preprocess_chunk_size: int = 4096,
    ) -> None:
        """Construct the MNIST dataset with optional on-disk caching.

        Keyword-only options:
            cache: if True, store the raw data (and optionally the
                preprocessed data) as compressed ``.npz`` files under `root`.
            scale: multiplicative factor applied to pixel values (e.g. 1/255).
            resize: target ``(H, W)`` to resize images to.
            normalize: ``(mean, std)`` tuples applied channel-wise.
            cache_preprocessed: also cache the preprocessed images to disk.
            preprocess_dtype: storage dtype of the preprocessed cache.
            preprocess_chunk_size: number of images transformed per chunk.
        """
        # NOTE(review): these attributes are deliberately set BEFORE the base
        # __init__, which presumably triggers download/loading that reads
        # them — TODO confirm against DatasetBase.
        self.cache = cache
        self.scale = scale
        self.resize = resize
        self.normalize = normalize
        self.cache_preprocessed = cache_preprocessed
        self.preprocess_dtype = preprocess_dtype
        self.preprocess_chunk_size = preprocess_chunk_size

        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
            test_size=test_size,
            to_tensor=to_tensor,
        )
57
+
20
58
  def _download(self) -> None:
21
59
  try:
22
60
  dataset = openml.datasets.get_dataset(self.OPENML_ID)
@@ -26,18 +64,106 @@ class MNIST(DatasetBase):
26
64
  except Exception as e:
27
65
  raise RuntimeError(f"Failed to download the MNIST dataset. Error: {e}")
28
66
 
29
- def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
67
+ def _cache_key(self) -> str:
68
+ parts: list[str] = []
69
+ if self.scale is not None:
70
+ parts.append(f"s{self.scale:g}")
71
+ if self.resize is not None:
72
+ parts.append(f"r{self.resize[0]}x{self.resize[1]}")
73
+ if self.normalize is not None:
74
+ mean, std = self.normalize
75
+ parts.append("m" + ",".join(f"{v:g}" for v in mean))
76
+ parts.append("v" + ",".join(f"{v:g}" for v in std))
77
+ if not parts:
78
+ return "raw"
79
+ key = "_".join(parts)
80
+ return re.sub(r"[^a-zA-Z0-9_,.x-]+", "_", key)
81
+
82
    def _raw_cache_path(self) -> Path:
        # Location of the raw cache: unprocessed int16 pixels + labels.
        return self.root / "MNIST_int16.npz"
84
+
85
+ def _proc_cache_path(self) -> Path:
86
+ dtype_name = str(self.preprocess_dtype)
87
+ return self.root / f"MNIST_{self._cache_key()}_{dtype_name}.npz"
88
+
89
    def _ensure_raw_cache(self) -> tuple[np.ndarray, np.ndarray]:
        """Return (images, labels) as numpy arrays, using the raw cache.

        Fast path: load the compressed ``.npz`` cache when caching is on and
        the file exists. Slow path: parse the downloaded CSV, reshape to
        ``(N, 1, 28, 28)``, and (if caching) write the cache for next time.

        Raises:
            RuntimeError: if neither the cache nor the CSV is present.
        """
        raw_path = self._raw_cache_path()
        if self.cache and raw_path.exists():
            with np.load(raw_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
            return images, labels

        csv_path = self.root / "MNIST.csv"
        if not csv_path.exists():
            raise RuntimeError(
                f"MNIST dataset CSV file not found at {csv_path}. Use `download=True`."
            )

        df = pd.read_csv(csv_path)
        labels = df["class"].values.astype(np.int32)
        images = df.drop(columns=["class"]).values.astype(np.float32)
        # Pixels are 0..255, so int16 is lossless and halves the cache size.
        images = images.reshape(-1, 1, 28, 28).astype(np.int16)

        if self.cache:
            np.savez_compressed(raw_path, images=images, labels=labels)

        return images, labels
112
+
113
    def _maybe_preprocess_and_cache(
        self, images_int16: np.ndarray, labels_int32: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply resize/scale/normalize (if configured) with disk caching.

        Returns float32 copies unchanged when no preprocessing is configured.
        Otherwise loads the preprocessed cache if available, or runs the
        transform pipeline in chunks and optionally writes the cache.
        """
        if self.resize is None and self.scale is None and self.normalize is None:
            return images_int16.astype(np.float32), labels_int32

        proc_path = self._proc_cache_path()
        if self.cache and self.cache_preprocessed and proc_path.exists():
            with np.load(proc_path) as npz:
                images = npz["images"]
                labels = npz["labels"]
            return images, labels

        # Imported lazily to avoid a hard dependency at module import time.
        from lucid.transforms import Compose, Resize, Normalize

        # Minimal transform that multiplies pixels by a constant factor.
        class _Scale(lucid.nn.Module):
            def __init__(self, factor: float) -> None:
                super().__init__()
                self.factor = factor

            def forward(self, x: Tensor) -> Tensor:
                return x * self.factor

        # Pipeline order: resize -> scale -> normalize.
        transforms: list[lucid.nn.Module] = []
        if self.resize is not None:
            transforms.append(Resize(self.resize))
        if self.scale is not None:
            transforms.append(_Scale(self.scale))
        if self.normalize is not None:
            mean, std = self.normalize
            transforms.append(Normalize(mean=mean, std=std))

        transform = Compose(transforms)
        n = images_int16.shape[0]
        out_h, out_w = self.resize if self.resize is not None else (28, 28)

        out_dtype = np.float16 if self.preprocess_dtype == lucid.Float16 else np.float32
        out_images = np.empty((n, 1, out_h, out_w), dtype=out_dtype)

        # Process in chunks to bound peak memory.
        for start in range(0, n, self.preprocess_chunk_size):
            end = min(start + self.preprocess_chunk_size, n)
            chunk = images_int16[start:end].astype(np.float32)
            x = lucid.to_tensor(chunk, dtype=lucid.Float32)
            x = transform(x)
            out_images[start:end] = x.numpy().astype(out_dtype, copy=False)

        if self.cache and self.cache_preprocessed:
            np.savez_compressed(proc_path, images=out_images, labels=labels_int32)

        return out_images, labels_int32
163
+
164
+ def _load_data(self, split: str) -> Tuple[Tensor, Tensor]:
165
+ images, labels = self._ensure_raw_cache()
166
+ images, labels = self._maybe_preprocess_and_cache(images, labels)
41
167
 
42
168
  train_size = int(math.floor(len(images) * (1 - self.test_size)))
43
169
  if split == "train":
@@ -46,13 +172,16 @@ class MNIST(DatasetBase):
46
172
  images, labels = images[train_size:], labels[train_size:]
47
173
 
48
174
  if self.to_tensor:
49
- images = lucid.to_tensor(images, dtype=lucid.Float32)
175
+ if images.dtype == np.float16 and self.preprocess_dtype == lucid.Float16:
176
+ images = lucid.to_tensor(images, dtype=lucid.Float16)
177
+ else:
178
+ images = lucid.to_tensor(images, dtype=lucid.Float32)
50
179
  labels = lucid.to_tensor(labels, dtype=lucid.Int32)
51
180
 
52
181
  return images, labels
53
182
 
54
183
  def __getitem__(self, index: SupportsIndex) -> Tuple[Tensor, Tensor]:
55
- image = self.data[index].reshape(-1, 1, 28, 28)
184
+ image = self.data[index]
56
185
  label = self.targets[index]
57
186
 
58
187
  if self.transform:
@@ -1,2 +1,3 @@
1
1
  from .ddpm import *
2
2
  from .vae import *
3
+ from .ncsn import *
@@ -0,0 +1,402 @@
1
+ import math
2
+ from typing import Literal, Sequence
3
+
4
+ import lucid
5
+ import lucid.nn as nn
6
+ import lucid.nn.functional as F
7
+
8
+ from lucid._tensor import Tensor
9
+
10
+
11
+ __all__ = ["NCSN"]
12
+
13
+
14
class _CondInstanceNorm(nn.Module):
    """Conditional instance normalization.

    Applies non-affine InstanceNorm2d, then a per-class affine transform
    whose (gamma, beta) pair is looked up from an embedding indexed by `y`.
    """

    def __init__(self, num_features: int, num_classes: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.num_features = num_features
        self.norm = nn.InstanceNorm2d(num_features, affine=False, eps=eps)
        # One row per class: first half gamma, second half beta.
        self.embed = nn.Embedding(num_classes, num_features * 2)

        # Identity initialization: gamma = 1, beta = 0.
        nn.init.constant(self.embed.weight[:, :num_features], 1.0)
        nn.init.constant(self.embed.weight[:, num_features:], 0.0)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Embedding lookup requires integer indices.
        if y.dtype != lucid.Long:
            y = y.long()

        h = self.norm(x)
        gamma_beta = self.embed(y)
        gamma, beta = lucid.chunk(gamma_beta, 2, axis=1)

        # Broadcast per-sample affine parameters over H and W.
        gamma = gamma.reshape(-1, self.num_features, 1, 1)
        beta = beta.reshape(-1, self.num_features, 1, 1)

        return h * gamma + beta
36
+
37
+
38
class _Conv3x3(nn.Module):
    """3x3 stride-1 convolution whose padding equals its dilation, so the
    spatial size of the input is preserved."""

    def __init__(
        self, in_channels: int, out_channels: int, dilation: int = 1, bias: bool = True
    ) -> None:
        super().__init__()
        # For a 3x3 kernel, padding == dilation keeps H x W unchanged.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        return out
55
+
56
+
57
class _ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: two (cond-norm -> ELU -> 3x3 conv)
    stages added back onto the input."""

    def __init__(self, channels: int, num_classes: int, dilation: int = 1) -> None:
        super().__init__()
        self.norm1 = _CondInstanceNorm(channels, num_classes)
        self.conv1 = _Conv3x3(channels, channels, dilation=dilation)

        self.norm2 = _CondInstanceNorm(channels, num_classes)
        self.conv2 = _Conv3x3(channels, channels, dilation=dilation)

        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # `y` selects the per-class affine parameters inside the norms.
        h = self.conv1(self.act(self.norm1(x, y)))
        h = self.conv2(self.act(self.norm2(h, y)))

        return x + h
73
+
74
+
75
class _RCUBlock(nn.Module):
    """A fixed-length chain of residual conv units applied in sequence."""

    def __init__(
        self, channels: int, num_classes: int, num_units: int = 2, dilation: int = 1
    ) -> None:
        super().__init__()
        units = []
        for _ in range(num_units):
            units.append(_ResidualConvUnit(channels, num_classes, dilation=dilation))
        self.units = nn.ModuleList(units)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        out = x
        for rcu in self.units:
            out = rcu(out, y)
        return out
92
+
93
+
94
class _CondAdapter(nn.Module):
    """Channel adapter: projects in_channels -> out_channels through a
    conditional pre-activation 1x1 conv, or acts as the identity when the
    channel counts already match (no parameters created in that case)."""

    def __init__(self, in_channels: int, out_channels: int, num_classes: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if in_channels == out_channels:
            # Identity path — keep attributes present so forward can branch.
            self.norm = None
            self.conv = None
        else:
            self.norm = _CondInstanceNorm(in_channels, num_classes)
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if self.in_channels == self.out_channels:
            return x
        return self.conv(self.act(self.norm(x, y)))
114
+
115
+
116
class _MultiResFusion(nn.Module):
    """Multi-resolution fusion: each input goes through
    (cond-norm -> ELU -> 3x3 conv), is upsampled to the largest spatial size
    among the inputs, and the results are summed."""

    def __init__(
        self, in_channels_arr: Sequence[int], out_channels: int, num_classes: int
    ) -> None:
        super().__init__()
        self.out_channels = out_channels
        self.norms = nn.ModuleList(
            [_CondInstanceNorm(c, num_classes) for c in in_channels_arr]
        )
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(c, out_channels, kernel_size=3, stride=1, padding=1)
                for c in in_channels_arr
            ]
        )
        self.act = nn.ELU()

    def forward(self, xs: Sequence[Tensor], y: Tensor) -> Tensor:
        if len(xs) != len(self.convs):
            raise ValueError(f"Expected {len(self.convs)} inputs, got {len(xs)}")

        # Target size is the largest H and W over all inputs.
        target_h = max(x.shape[-2] for x in xs)
        target_w = max(x.shape[-1] for x in xs)
        fused = None

        for x, norm, conv in zip(xs, self.norms, self.convs):
            h = conv(self.act(norm(x, y)))
            if h.shape[-2:] != (target_h, target_w):
                h = F.interpolate(h, size=(target_h, target_w), mode="nearest")

            # Running sum of the upsampled branches.
            fused = h if fused is None else fused + h

        return fused
149
+
150
+
151
class _ChainedResPooling(nn.Module):
    """Chained residual pooling: repeated
    (cond-norm -> ELU -> 5x5 stride-1 max-pool -> 3x3 conv) stages whose
    outputs are accumulated onto the input."""

    def __init__(self, channels: int, num_classes: int, num_stages: int = 4) -> None:
        super().__init__()
        self.norms = nn.ModuleList(
            [_CondInstanceNorm(channels, num_classes) for _ in range(num_stages)]
        )
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
                for _ in range(num_stages)
            ]
        )
        self.act = nn.ELU()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        h = x       # chained branch: each stage feeds the next
        out = x     # accumulator: input plus every stage's output
        for norm, conv in zip(self.norms, self.convs):
            h = self.act(norm(h, y))
            # stride=1 with padding=2 keeps the spatial size for a 5x5 window.
            h = F.max_pool2d(h, kernel_size=5, stride=1, padding=2)
            h = conv(h)
            out = out + h

        return out
175
+
176
+
177
class _RefineBlock(nn.Module):
    """RefineNet block: adapt every input to a common channel width, run
    per-input RCUs, fuse the multi-resolution features, apply chained
    residual pooling, then a final output RCU."""

    def __init__(
        self, in_channels_arr: Sequence[int], out_channels: int, num_classes: int
    ) -> None:
        super().__init__()
        self.adapters = nn.ModuleList(
            [_CondAdapter(c, out_channels, num_classes) for c in in_channels_arr]
        )
        self.rcu_in = nn.ModuleList(
            [_RCUBlock(out_channels, num_classes, num_units=2) for _ in in_channels_arr]
        )
        self.msf = _MultiResFusion(
            [out_channels] * len(in_channels_arr), out_channels, num_classes
        )
        self.crp = _ChainedResPooling(out_channels, num_classes, num_stages=4)
        self.rcu_out = _RCUBlock(out_channels, num_classes, num_units=2)

    def forward(self, xs: Sequence[Tensor], y: Tensor) -> Tensor:
        if len(xs) != len(self.adapters):
            raise ValueError(f"Expected {len(self.adapters)} inputs, got {len(xs)}")

        hs: list[Tensor] = []
        for x, adapter, rcu in zip(xs, self.adapters, self.rcu_in):
            h = adapter(x, y)
            h = rcu(h, y)
            hs.append(h)

        # A single input needs no fusion; otherwise merge all resolutions.
        h = hs[0] if len(hs) == 1 else self.msf(hs, y)
        h = self.crp(h, y)
        h = self.rcu_out(h, y)

        return h
209
+
210
+
211
@nn.set_state_dict_pass_attr("sigmas")
class NCSN(nn.Module):
    """Noise Conditional Score Network.

    A RefineNet-style conditional score model: ``forward(x, labels)``
    estimates the score of ``x`` at the noise level selected by ``labels``
    (integer indices into the ``sigmas`` buffer). When ``scale_by_sigma`` is
    True the raw output is divided by the selected sigma. The noise schedule
    must be installed via ``set_sigmas`` (see ``make_sigmas``) before the
    scaled forward pass, loss, or sampling are used.
    """

    def __init__(
        self,
        in_channels: int = 3,
        nf: int = 128,
        num_classes: int = 10,
        dilations: Sequence[int] = (1, 2, 4, 8),
        scale_by_sigma: bool = True,
    ) -> None:
        super().__init__()
        if len(dilations) != 4:
            raise ValueError("Expected 4 dilation values (for 4 RefineNet stages).")

        self.in_channels = in_channels
        self.nf = nf
        self.num_classes = num_classes
        self.scale_by_sigma = bool(scale_by_sigma)

        # Noise schedule buffer; left uninitialized until set_sigmas is called.
        self.sigmas: nn.Buffer
        self.register_buffer("sigmas", lucid.empty(num_classes))

        self.begin_conv = nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)

        # Four cascaded RCU stages with per-stage dilation.
        self.stage1 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[0])
        self.stage2 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[1])
        self.stage3 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[2])
        self.stage4 = _RCUBlock(nf, num_classes, num_units=2, dilation=dilations[3])

        # Top-down refinement path; deeper blocks feed into shallower ones.
        self.refine4 = _RefineBlock([nf], nf, num_classes)
        self.refine3 = _RefineBlock([nf, nf], nf, num_classes)
        self.refine2 = _RefineBlock([nf, nf], nf, num_classes)
        self.refine1 = _RefineBlock([nf, nf], nf, num_classes)

        self.end_norm = _CondInstanceNorm(nf, num_classes)
        self.end_act = nn.ELU()
        self.end_conv = nn.Conv2d(nf, in_channels, kernel_size=3, stride=1, padding=1)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        # Xavier-uniform for conv weights, zeros for biases.
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform(m.weight)
            if m.bias is not None:
                nn.init.constant(m.bias, 0.0)

    def forward(self, x: Tensor, labels: Tensor) -> Tensor:
        """Return the estimated score of `x` at the noise levels `labels`.

        Raises:
            RuntimeError: if `scale_by_sigma` is set but the `sigmas` buffer
                has not been filled to length `num_classes`.
        """
        if labels.ndim != 1:
            labels = labels.reshape(-1)

        h = self.begin_conv(x)
        h1 = self.stage1(h, labels)
        h2 = self.stage2(h1, labels)
        h3 = self.stage3(h2, labels)
        h4 = self.stage4(h3, labels)

        r4 = self.refine4([h4], labels)
        r3 = self.refine3([h3, r4], labels)
        r2 = self.refine2([h2, r3], labels)
        r1 = self.refine1([h1, r2], labels)

        out = self.end_conv(self.end_act(self.end_norm(r1, labels)))
        if self.scale_by_sigma:
            if self.sigmas.size != self.num_classes:
                raise RuntimeError(
                    f"'sigmas' buffer has shape {self.sigmas.shape}; "
                    f"expected ({self.num_classes},). Call 'set_sigmas(...)'."
                )

            used_sigmas = self.sigmas[labels].reshape(-1, 1, 1, 1)
            out = out / used_sigmas

        return out

    @lucid.no_grad()
    def set_sigmas(self, sigmas: Tensor) -> None:
        """Copy a 1D noise schedule of length `num_classes` into `sigmas`.

        Raises:
            ValueError: if `sigmas` is not 1D or has the wrong length.
        """
        if sigmas.ndim != 1:
            raise ValueError("sigmas must be 1D.")
        if sigmas.size != self.num_classes:
            raise ValueError(
                f"sigmas length ({sigmas.size}) must match "
                f"num_classes ({self.num_classes})."
            )
        # BUG FIX: the return values of `.to(...)` were previously discarded
        # (`tmp.to(...)` on two bare statements). Keep the returned tensor so
        # the device/dtype conversion takes effect even when `.to` returns a
        # new tensor rather than mutating in place.
        tmp = sigmas.detach().to(self.sigmas.device).to(self.sigmas.dtype)
        self.sigmas.data = tmp.data

    @staticmethod
    @lucid.no_grad()
    def make_sigmas(sigma_begin: float, sigma_end: float, num_scales: int) -> Tensor:
        """Build a geometric (log-linear) descending noise schedule.

        Raises:
            ValueError: if the bounds are non-positive, non-descending, or
                fewer than two scales are requested.
        """
        if sigma_begin <= 0 or sigma_end <= 0:
            raise ValueError("sigmas must be positive.")
        if sigma_begin <= sigma_end:
            raise ValueError(
                "Expected sigma_begin > sigma_end (descending noise schedule)."
            )
        if num_scales < 2:
            raise ValueError("num_scales must be >= 2.")

        return lucid.exp(
            lucid.linspace(math.log(sigma_begin), math.log(sigma_end), num_scales)
        )

    def get_loss(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Denoising score-matching loss over randomly drawn noise levels.

        Returns the scalar loss and the sampled per-sample level indices.
        """
        batch_size = x.shape[0]
        labels = lucid.random.randint(
            0, self.sigmas.shape[0], (batch_size,), device=x.device
        ).long()
        used_sigmas = self.sigmas[labels].reshape(batch_size, 1, 1, 1)

        noise = lucid.random.randn(x.shape, device=x.device)
        perturbed = x + used_sigmas * noise
        score = self.forward(perturbed, labels)

        # sigma-weighted objective: || sigma * score + noise ||^2
        loss = lucid.sum((score * used_sigmas + noise) ** 2, axis=(1, 2, 3)).mean()
        return loss, labels

    @lucid.no_grad()
    def sample(
        self,
        n_samples: int,
        image_size: int,
        in_channels: int,
        n_steps_each: int,
        step_lr: float,
        clip: bool = True,
        denoise: bool = False,
        init: Tensor | None = None,
        init_dist: Literal["uniform", "normal"] = "uniform",
        verbose: bool = True,
    ) -> Tensor:
        """Annealed Langevin dynamics sampling over the sigma schedule.

        Raises:
            ValueError: for an unknown `init_dist` or a mis-shaped `init`.
        """
        self.eval()
        if init is None:
            if init_dist == "uniform":
                x = lucid.random.uniform(
                    -1.0,
                    1.0,
                    (n_samples, in_channels, image_size, image_size),
                    device=self.device,
                )
            elif init_dist == "normal":
                x = lucid.random.randn(
                    n_samples, in_channels, image_size, image_size, device=self.device
                )
            else:
                raise ValueError("init_dist must be either 'uniform' or 'normal'.")

        else:
            x = init.to(self.device)
            if x.shape != (n_samples, in_channels, image_size, image_size):
                raise ValueError(
                    f"init has shape {x.shape} but expected "
                    f"{(n_samples, in_channels, image_size, image_size)}."
                )

        # Imported lazily so sampling is the only path that needs tqdm.
        from tqdm import tqdm

        total = int(self.sigmas.shape[0]) * int(n_steps_each)
        pbar = (
            tqdm(total=total, desc="Sampling", dynamic_ncols=True) if verbose else None
        )
        for i, sigma in enumerate(self.sigmas):
            labels = lucid.full(n_samples, i, device=self.device, dtype=lucid.Long)
            # Step size annealed relative to the smallest sigma.
            step_size = step_lr * (sigma / self.sigmas[-1]) ** 2

            for j in range(n_steps_each):
                grad = self.forward(x, labels)
                noise = lucid.random.randn(x.shape, device=self.device)

                # Langevin update: drift along the score plus scaled noise.
                x = x + step_size * grad + lucid.sqrt(2.0 * step_size) * noise
                if clip:
                    x = x.clip(-1.0, 1.0)
                if verbose:
                    pbar.update(1)
                    pbar.set_postfix(sigma=f"{sigma.item():.4f}", l=i, t=j)

        if denoise:
            # Final deterministic denoising step at the smallest noise level.
            last_label = lucid.full(
                n_samples,
                self.sigmas.shape[0] - 1,
                device=self.device,
                dtype=lucid.Long,
            )
            x = x + (self.sigmas[-1] ** 2) * self.forward(x, last_label)
            if clip:
                x = x.clip(-1.0, 1.0)

        if verbose:
            pbar.close()
        return x
@@ -0,0 +1 @@
1
+ pass