lt-tensor 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,596 @@
+ import gc
+ import random
+ import numpy as np
+ from lt_utils.types_utils import is_str
+ from ._torch_commons import *
+ from lt_utils.misc_utils import log_traceback, cache_wrapper
+ from lt_utils.files_ops import load_json, load_yaml, save_json, save_yaml
+ import math
+
+
+ def log_tensor(
+     item: Union[Tensor, np.ndarray],
+     title: Optional[str] = None,
+     print_details: bool = True,
+     print_tensor: bool = False,
+     dim: Optional[int] = None,
+ ):
+     assert isinstance(item, (Tensor, np.ndarray))
+     has_title = is_str(title)
+
+     if has_title:
+         print("========[" + title.title() + "]========")
+         _b = 20 + len(title.strip())
+     print(f"shape: {item.shape}")
+     print(f"dtype: {item.dtype}")
+     if print_details:
+         print(f"ndim: {item.ndim}")
+         if isinstance(item, Tensor):
+             print(f"device: {item.device}")
+         print(f"min: {item.min():.4f}")
+         print(f"max: {item.max():.4f}")
+         if dim is None:
+             print(f"std: {item.std():.4f}")
+             print(f"mean: {item.mean():.4f}")
+         else:
+             # positional argument works as `dim` for tensors and as `axis` for numpy arrays
+             print(f"std: {item.std(dim)}")
+             print(f"mean: {item.mean(dim)}")
+     if print_tensor:
+         print(item)
+     if has_title:
+         print("".join(["-"] * _b), "\n")
+
+
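For reference, a minimal usage sketch of `log_tensor` (assuming the helpers above are importable from this module; the exact import path is not shown in the diff):

```
import torch

x = torch.randn(2, 80, 256)
log_tensor(x, title="mel batch")          # shape/dtype/ndim/device/min/max/std/mean
log_tensor(x, title="mel batch", dim=-1)  # per-dimension std/mean instead of scalars
```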
+ def set_seed(seed: int):
+     """Set random seed for reproducibility."""
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     if torch.mps.is_available():
+         torch.mps.manual_seed(seed)
+     if torch.xpu.is_available():
+         torch.xpu.manual_seed_all(seed)
+
+
+ def count_parameters(model: Module) -> int:
+     """Returns total number of trainable parameters."""
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def freeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
+     """Freezes all model parameters except the specified layers."""
+     no_exceptions = not except_layers
+     for name, param in model.named_parameters():
+         if no_exceptions:
+             param.requires_grad_(False)
+         elif not any(layer in name for layer in except_layers):
+             param.requires_grad_(False)
+
+
+ def freeze_selected_weights(model: Module, target_layers: list[str]):
+     """Freezes only the parameters of the specified layers."""
+     for name, param in model.named_parameters():
+         if any(layer in name for layer in target_layers):
+             param.requires_grad_(False)
+
+
+ def unfreeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
+     """Unfreezes all model parameters except the specified layers."""
+     no_exceptions = not except_layers
+     for name, param in model.named_parameters():
+         if no_exceptions:
+             param.requires_grad_(True)
+         elif not any(layer in name for layer in except_layers):
+             param.requires_grad_(True)
+
+
+ def unfreeze_selected_weights(model: Module, target_layers: list[str]):
+     """Unfreezes only the parameters of the specified layers."""
+     for name, param in model.named_parameters():
+         if any(layer in name for layer in target_layers):
+             param.requires_grad_(True)
+
+
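A minimal sketch of how the freeze helpers and `count_parameters` compose (the two-layer model here is only illustrative):

```
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 4))
freeze_all_except(model, except_layers=["1"])  # only the second Linear stays trainable
print(count_parameters(model))                 # 132 = 32*4 weights + 4 biases
```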
+ def clip_gradients(model: Module, max_norm: float = 1.0):
+     """Applies gradient clipping."""
+     return nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+
+
+ def detach_hidden(hidden):
+     """Detaches hidden states (for RNNs)."""
+     if isinstance(hidden, torch.Tensor):
+         return hidden.detach()
+     else:
+         return tuple(detach_hidden(h) for h in hidden)
+
+
+ def tensor_summary(tensor: torch.Tensor) -> str:
+     """Returns a one-line min/max/mean/std summary of a tensor for debugging."""
+     return (
+         f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, "
+         f"min: {tensor.min():.4f}, max: {tensor.max():.4f}, "
+         f"mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"
+     )
+
+
+ def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
+     """One-hot encodes a tensor of labels."""
+     return F.one_hot(labels, num_classes).float()
+
+
+ def safe_divide(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
+     """Safe division for tensors (prevents divide-by-zero)."""
+     return a / (b + eps)
+
+
+ def batch_pad(tensors: list[torch.Tensor], padding_value: float = 0.0) -> torch.Tensor:
+     """Pads a list of tensors to the same shape (assumes 2D+ tensors)."""
+     max_shape = [
+         max(s[i] for s in [t.shape for t in tensors]) for i in range(tensors[0].dim())
+     ]
+     padded = []
+     for t in tensors:
+         pad_dims = [(0, m - s) for s, m in zip(t.shape, max_shape)]
+         pad_flat = [p for pair in reversed(pad_dims) for p in pair]  # reverse for F.pad
+         padded.append(F.pad(t, pad_flat, value=padding_value))
+     return torch.stack(padded)
+
+
+ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
+     """Randomly samples values from tensor for preview."""
+     flat = tensor.flatten()
+     idx = torch.randperm(len(flat))[:num_samples]
+     return flat[idx]
+
+
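A quick sketch of `batch_pad` on two differently sized 2D tensors (shapes chosen only for illustration):

```
import torch

a = torch.ones(3, 5)
b = torch.ones(2, 7)
out = batch_pad([a, b])   # both padded to (3, 7), then stacked
print(out.shape)          # torch.Size([2, 3, 7])
```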
+ class TorchCacheUtils:
+     cached_shortcuts: dict[str, Callable[[], None]] = {}
+
+     has_cuda: bool = torch.cuda.is_available()
+     has_xpu: bool = torch.xpu.is_available()
+     has_mps: bool = torch.mps.is_available()
+
+     _ignore: list[str] = []
+
+     def __init__(self):
+         pass
+
+     def _apply_clear(self, device: str):
+         if device in self._ignore:
+             gc.collect()
+             return
+         try:
+             clear_fn = self.cached_shortcuts.get(
+                 device, getattr(torch, device).empty_cache
+             )
+             if device not in self.cached_shortcuts:
+                 self.cached_shortcuts.update({device: clear_fn})
+             clear_fn()
+         except Exception as e:
+             print(e)
+             self._ignore.append(device)
+
+     def clear(self):
+         gc.collect()
+         if self.has_xpu:
+             self._apply_clear("xpu")
+         if self.has_cuda:
+             self._apply_clear("cuda")
+         if self.has_mps:
+             self._apply_clear("mps")
+         gc.collect()
+
+
+ _clear_cache_cls = TorchCacheUtils()
+
+
+ def clear_cache():
+     _clear_cache_cls.clear()
+
+
+ @cache_wrapper
+ def default_device(idx: Optional[int] = None):
+     try:
+         if torch.cuda.is_available():
+             return torch.device("cuda", idx)
+         if torch.xpu.is_available():
+             return torch.device("xpu", idx)
+         if torch.mps.is_available():
+             return torch.device("mps", idx)
+         if hasattr(torch, "is_vulkan_available"):
+             if getattr(torch, "is_vulkan_available")():
+                 return torch.device("vulkan", idx)
+     except Exception:
+         pass
+     # fallback: whatever device new tensors are currently allocated on
+     return torch.device(torch.zeros(1).device)
+
+
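A minimal sketch of how `default_device` and `clear_cache` are meant to be used together (assuming both are imported from this module; which backend is picked obviously depends on the host):

```
import torch

device = default_device()              # cuda/xpu/mps when available, else the current default device
x = torch.randn(4, 4, device=device)
clear_cache()                          # gc.collect() plus empty_cache() on every available backend
```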
+ class Packing:
+     """
+     Example:
+
+     ```
+     x_lengths = torch.tensor([5, 3, 6])
+     x_padded = torch.randn(3, 6, 256)  # padded input [B, T, C]
+
+     # 1. RNN expects packed input
+     x_packed = Packing.pack_sequence(x_padded, x_lengths)
+     output_packed, _ = rnn(x_packed)
+
+     # 2. Recover padded output for the loss
+     output = Packing.unpack_sequence(output_packed, total_length=x_padded.size(1))
+
+     # 3. Mask for the loss
+     mask = torch.arange(x_padded.size(1))[None, :] < x_lengths[:, None]
+     loss = (F.mse_loss(output, target, reduction="none") * mask.unsqueeze(-1)).sum() / mask.sum()
+     ```
+     """
+
+     @staticmethod
+     def pack_sequence(x: Tensor, lengths: Tensor):
+         """
+         Pack a padded sequence for RNN/LSTM.
+         Args:
+             x (Tensor): Padded input [B, T, C]
+             lengths (Tensor): Actual lengths [B]
+         Returns:
+             PackedSequence
+         """
+         return nn.utils.rnn.pack_padded_sequence(
+             x,
+             lengths.cpu().numpy(),
+             batch_first=True,
+             enforce_sorted=False,
+         )
+
+     @staticmethod
+     def unpack_sequence(packed, total_length: int) -> Tensor:
+         """Unpacks an RNN PackedSequence to padded [B, T, C]."""
+         output, _ = nn.utils.rnn.pad_packed_sequence(
+             packed,
+             batch_first=True,
+             total_length=total_length,
+         )
+         return output
+
+
+ class Padding:
+
+     @staticmethod
+     def pad_to(x: Tensor, target_length: int, pad_value: float = 0.0) -> Tensor:
+         """
+         Pad input tensor along time axis (dim=1) to target length.
+         Args:
+             x (Tensor): Input tensor [B, T, C]
+             target_length (int): Target time length
+             pad_value (float): Fill value
+         Returns:
+             Padded tensor [B, target_length, C]
+         """
+         B, T, C = x.size()
+         if T >= target_length:
+             return x
+         pad = x.new_full((B, target_length - T, C), pad_value)
+         return torch.cat([x, pad], dim=1)
+
+     @staticmethod
+     def pad_sequence(
+         inputs: Tensor,
+         size: int,
+         direction: Literal["left", "right"] = "left",
+         pad_id: Union[int, float] = 0,
+     ) -> Tensor:
+         """
+         Pads a single tensor to the specified size in 1D.
+         Args:
+             inputs (Tensor): Tensor of shape [T] or [B, T]
+             size (int): Desired size along the last dimension
+             direction (str): 'left' or 'right'
+             pad_id (int): Value to pad with
+         Returns:
+             Padded tensor
+         """
+         total = size - inputs.shape[-1]
+         if total < 1:
+             return inputs
+         pad_config = (total, 0) if direction == "left" else (0, total)
+         return F.pad(inputs, pad_config, value=pad_id)
+
+     @staticmethod
+     def pad_batch_1d(
+         batch: List[Tensor],
+         pad_value: float = 0.0,
+         pad_to_multiple: Optional[int] = None,
+         direction: Literal["left", "right"] = "right",
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Pad list of 1D tensors to same length with optional multiple alignment.
+         Returns:
+             Padded tensor [B, T], Lengths [B]
+         """
+         lengths = torch.tensor([t.size(0) for t in batch])
+         max_len = lengths.max().item()
+
+         if pad_to_multiple:
+             max_len = (
+                 (max_len + pad_to_multiple - 1) // pad_to_multiple
+             ) * pad_to_multiple
+
+         padded = []
+         for t in batch:
+             padded.append(Padding.pad_sequence(t, max_len, direction, pad_value))
+         return torch.stack(padded), lengths
+
+     @staticmethod
+     def pad_batch_2d(
+         batch: List[Tensor],
+         pad_value: float = 0.0,
+         pad_to_multiple: Optional[int] = None,
+         direction: Literal["left", "right"] = "right",
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Pad list of 2D tensors (e.g. [T, D]) to same T.
+         Returns:
+             Padded tensor [B, T, D], Lengths [B]
+         """
+         lengths = torch.tensor([t.size(0) for t in batch])
+         feat_dim = batch[0].size(1)
+         max_len = lengths.max().item()
+
+         if pad_to_multiple:
+             max_len = (
+                 (max_len + pad_to_multiple - 1) // pad_to_multiple
+             ) * pad_to_multiple
+
+         padded = []
+         for t in batch:
+             pad_len = max_len - t.size(0)
+             pad_tensor = t.new_full((pad_len, feat_dim), pad_value)
+             if direction == "left":
+                 padded.append(torch.cat([pad_tensor, t], dim=0))
+             else:
+                 padded.append(torch.cat([t, pad_tensor], dim=0))
+         return torch.stack(padded), lengths
+
+     @staticmethod
+     def pad_batch_nd(
+         batch: List[Tensor],
+         pad_value: float = 0.0,
+         dim: int = 0,
+         pad_to_multiple: Optional[int] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         General N-D padding along time axis (dim=0, usually).
+         Handles shapes like:
+             [T, C] → [B, T, C]
+             [T, H, W] → [B, T, H, W]
+         """
+         lengths = torch.tensor([t.size(dim) for t in batch])
+         max_len = lengths.max().item()
+         if pad_to_multiple:
+             max_len = (
+                 (max_len + pad_to_multiple - 1) // pad_to_multiple
+             ) * pad_to_multiple
+
+         padded = []
+         for t in batch:
+             pad_len = max_len - t.size(dim)
+             pad_shape = list(t.shape)
+             pad_shape[dim] = pad_len
+             pad_tensor = t.new_full(pad_shape, pad_value)
+             padded_tensor = torch.cat([t, pad_tensor], dim=dim)
+             padded.append(padded_tensor)
+
+         return torch.stack(padded), lengths
+
+
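A short sketch of the `Padding` batch helpers (shapes are illustrative):

```
import torch

seqs = [torch.randn(5, 80), torch.randn(3, 80), torch.randn(6, 80)]
feats, lengths = Padding.pad_batch_2d(seqs, pad_to_multiple=4)  # [3, 8, 80], lengths [5, 3, 6]

clips = [torch.randn(5, 3, 4), torch.randn(2, 3, 4)]
stacked, clip_lens = Padding.pad_batch_nd(clips)                # [2, 5, 3, 4]
```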
+ class MaskUtils:
+
+     @staticmethod
+     def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
+         """
+         Apply a mask to a tensor, setting masked positions to `fill_value`.
+         Args:
+             x (Tensor): Input tensor of shape [..., T, D].
+             mask (Tensor): Mask of shape [..., T] where True = masked.
+             fill_value (Number): Value to fill masked positions with.
+         Returns:
+             Tensor: Masked tensor.
+         """
+         return x.masked_fill(mask.unsqueeze(-1), fill_value)
+
+     @staticmethod
+     def get_padding_mask(
+         lengths: Optional[Tensor] = None,
+         tokens: Optional[Tensor] = None,
+         padding_id: int = 0,
+     ) -> Tensor:
+         """
+         Generate a padding mask: True for real tokens, False for padding.
+         Args:
+             lengths (Tensor): Tensor of shape [B] with sequence lengths.
+             tokens (Tensor): Tensor of shape [B, T] with token ids.
+             padding_id (int): Padding token id (default=0).
+         Returns:
+             Tensor: Boolean mask of shape [B, T].
+         """
+         assert (
+             tokens is not None or lengths is not None
+         ), "Either tokens or lengths must be provided."
+
+         if tokens is not None:
+             return tokens != padding_id
+
+         B = lengths.size(0)
+         max_len = lengths.max().item()
+         arange = torch.arange(max_len, device=lengths.device).unsqueeze(0).expand(B, -1)
+         return arange < lengths.unsqueeze(1)
+
+     @staticmethod
+     def get_padding_mask_fps(lengths: Tensor) -> Tensor:
+         """
+         Legacy-style padding mask using 1-based comparison (True on padded positions).
+         """
+         mask = (
+             torch.arange(lengths.max(), device=lengths.device)
+             .unsqueeze(0)
+             .expand(lengths.shape[0], -1)
+         )
+         return (mask + 1) > lengths.unsqueeze(1)
+
+     @staticmethod
+     def get_causal_mask(
+         size: Union[int, tuple[int, ...]],
+         device: Optional[Union[str, torch.device]] = None,
+     ) -> Tensor:
+         """
+         Generate a causal mask for self-attention.
+         Args:
+             size (int or tuple): Size (T) or (1, T, T)
+         Returns:
+             Tensor: [1, T, T] boolean causal mask
+         """
+         if isinstance(size, int):
+             size = (1, size, size)
+         return torch.tril(torch.ones(size, dtype=torch.bool, device=device))
+
+     @staticmethod
+     def combine_masks(pad_mask: Tensor, causal_mask: Tensor) -> Tensor:
+         """
+         Combine padding and causal masks.
+         Args:
+             pad_mask (Tensor): [B, T] padding mask
+             causal_mask (Tensor): [1, T, T] causal mask
+         Returns:
+             Tensor: [B, T, T] combined mask
+         """
+         return (
+             causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
+         )
+
+
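A minimal sketch combining the mask helpers into an attention-style mask:

```
import torch

lengths = torch.tensor([5, 3])
pad_mask = MaskUtils.get_padding_mask(lengths=lengths)  # [2, 5], True on real tokens
causal = MaskUtils.get_causal_mask(pad_mask.size(1))    # [1, 5, 5] lower-triangular
attn_mask = MaskUtils.combine_masks(pad_mask, causal)   # [2, 5, 5]
```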
+ def masked_cross_entropy(
+     logits: torch.Tensor,  # [B, T, V]
+     targets: torch.Tensor,  # [B, T]
+     lengths: torch.Tensor,  # [B]
+     reduction: str = "mean",
+ ) -> torch.Tensor:
+     """
+     CrossEntropyLoss with masking for variable-length sequences.
+     - logits: unnormalized scores [B, T, V]
+     - targets: ground-truth indices [B, T]
+     - lengths: actual sequence lengths [B]
+     """
+     B, T, V = logits.size()
+     logits = logits.view(-1, V)
+     targets = targets.view(-1)
+
+     # Mask out positions beyond each sequence's length
+     mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+     mask = mask.reshape(-1)
+
+     # Apply cross-entropy only where mask == True; `reduction` is passed through
+     return F.cross_entropy(logits[mask], targets[mask], reduction=reduction)
+
+
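A minimal sketch of `masked_cross_entropy` on a toy batch (values are illustrative):

```
import torch

logits = torch.randn(2, 6, 10, requires_grad=True)  # [B, T, V]
targets = torch.randint(0, 10, (2, 6))              # [B, T]
lengths = torch.tensor([6, 4])                      # positions past each length are ignored
loss = masked_cross_entropy(logits, targets, lengths)
loss.backward()
```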
+ class NoiseScheduler(Module):
+     def __init__(self, timesteps: int = 512):
+         super().__init__()
+
+         betas = torch.linspace(1e-4, 0.02, timesteps)
+         alphas = 1.0 - betas
+         alpha_cumprod = torch.cumprod(alphas, dim=0)
+
+         self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
+         self.register_buffer(
+             "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
+         )
+
+         self.timesteps = timesteps
+         self.default_noise = math.sqrt(1.25)
+
+     def get_random_noise(
+         self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
+     ) -> float:
+         if seed > 0:
+             random.seed(seed)
+         return random.uniform(*min_max)
+
+     def set_noise(
+         self,
+         seed: int = 0,
+         min_max: Tuple[float, float] = (-3, 3),
+         default: bool = False,
+     ):
+         self.default_noise = (
+             math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
+         )
+
+     def forward(
+         self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
+     ) -> Tensor:
+         if t < 0 or t >= self.timesteps:
+             raise ValueError(
+                 f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
+             )
+
+         if noise is None:
+             noise = self.default_noise
+
+         if isinstance(noise, (float, int)):
+             noise = torch.randn_like(x_0) * noise
+
+         alpha_term = self.sqrt_alpha_cumprod[t] * x_0
+         noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
+         return alpha_term + noise_term
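When `noise` is a standard-normal tensor, `NoiseScheduler.forward` is the usual DDPM-style forward process x_t = sqrt(alpha_cumprod[t]) * x_0 + sqrt(1 - alpha_cumprod[t]) * noise; with the scalar `default_noise` it instead scales freshly sampled Gaussian noise by that factor. A minimal sketch:

```
import torch

scheduler = NoiseScheduler(timesteps=512)
x_0 = torch.randn(8, 80, 128)
x_t = scheduler(x_0, t=100)                                # uses the scalar default_noise
x_t = scheduler(x_0, t=100, noise=torch.randn_like(x_0))   # standard q(x_t | x_0) sample
```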
@@ -0,0 +1,2 @@
+ from . import transformer_models as Transformer
+ from . import basic, residual
@@ -0,0 +1,65 @@
+ from .._torch_commons import *
+ from .._basics import Model
+ from ..transform import get_sinusoidal_embedding
+
+
+ class FeedForward(Model):
+     def __init__(
+         self,
+         d_model: int,
+         ff_dim: int,
+         dropout: float = 0.01,
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         normalizer: nn.Module = nn.Identity(),
+     ):
+         """Creates a feed-forward layer with the chosen activation function and normalizer."""
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(d_model, ff_dim),
+             activation,
+             nn.Dropout(dropout),
+             nn.Linear(ff_dim, d_model),
+             normalizer,
+         )
+
+     def forward(self, x: Tensor):
+         return self.net(x)
+
+
+ class MLP(Model):
+     def __init__(
+         self,
+         d_model: int,
+         ff_dim: int,
+         n_classes: int,
+         dropout: float = 0.01,
+         activation: nn.Module = nn.LeakyReLU(0.1),
+         normalizer: nn.Module = nn.Identity(),
+     ):
+         """Creates an MLP block with the chosen activation function and normalizer."""
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(d_model, ff_dim),
+             activation,
+             nn.Dropout(dropout),
+             nn.Linear(ff_dim, n_classes),
+             normalizer,
+         )
+
+     def forward(self, x: Tensor):
+         return self.net(x)
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, dim_emb: int, proj_dim: int):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(dim_emb, proj_dim),
+             nn.SiLU(),
+             nn.Linear(proj_dim, proj_dim),
+         )
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         # t: [B] (long)
+         emb = get_sinusoidal_embedding(t, self.mlp[0].in_features)  # [B, dim_emb]
+         return self.mlp(emb)  # [B, proj_dim]
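A usage sketch for the blocks above; this assumes `Model` behaves like a plain `nn.Module` subclass and that `get_sinusoidal_embedding(t, dim)` (not shown in this diff) returns a `[B, dim]` embedding:

```
import torch

ff = FeedForward(d_model=256, ff_dim=1024)
y = ff(torch.randn(4, 100, 256))        # [4, 100, 256]

temb = TimestepEmbedder(dim_emb=128, proj_dim=512)
h = temb(torch.randint(0, 512, (4,)))   # [4, 512], given the embedding shape assumed above
```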
@@ -0,0 +1,6 @@
+ __all__ = [
+     "ResidualBlock1D",
+     "Downsample1D",
+     "Upsample1D",
+     "DiffusionUNet",
+ ]