lt-tensor 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +0 -0
- lt_tensor/_basics.py +244 -0
- lt_tensor/_torch_commons.py +12 -0
- lt_tensor/lr_schedulers.py +108 -0
- lt_tensor/math_ops.py +120 -0
- lt_tensor/misc_utils.py +596 -0
- lt_tensor/model_zoo/__init__.py +2 -0
- lt_tensor/model_zoo/basic.py +65 -0
- lt_tensor/model_zoo/diffusion/__init__.py +6 -0
- lt_tensor/model_zoo/diffusion/models.py +114 -0
- lt_tensor/model_zoo/residual.py +236 -0
- lt_tensor/model_zoo/transformer_models/__init__.py +6 -0
- lt_tensor/model_zoo/transformer_models/models.py +132 -0
- lt_tensor/model_zoo/transformer_models/positional_encoders.py +95 -0
- lt_tensor/monotonic_align.py +70 -0
- lt_tensor/transform.py +113 -0
- lt_tensor-0.0.1.dev0.dist-info/METADATA +33 -0
- lt_tensor-0.0.1.dev0.dist-info/RECORD +21 -0
- lt_tensor-0.0.1.dev0.dist-info/WHEEL +5 -0
- lt_tensor-0.0.1.dev0.dist-info/licenses/LICENSE +201 -0
- lt_tensor-0.0.1.dev0.dist-info/top_level.txt +1 -0
lt_tensor/misc_utils.py
ADDED
@@ -0,0 +1,596 @@
import gc
import random
import numpy as np
from lt_utils.types_utils import is_str
from ._torch_commons import *
from lt_utils.misc_utils import log_traceback, cache_wrapper
from lt_utils.files_ops import load_json, load_yaml, save_json, save_yaml
import math


def log_tensor(
    item: Union[Tensor, np.ndarray],
    title: Optional[str] = None,
    print_details: bool = True,
    print_tensor: bool = False,
    dim: Optional[int] = None,
):
    assert isinstance(item, (Tensor, np.ndarray))
    has_title = is_str(title)

    if has_title:
        print("========[" + title.title() + "]========")
        _b = 20 + len(title.strip())
    print(f"shape: {item.shape}")
    print(f"dtype: {item.dtype}")
    if not print_details:
        print(f"ndim: {item.ndim}")
        if isinstance(item, Tensor):
            print(f"device: {item.device}")
        print(f"min: {item.min():.4f}")
        print(f"max: {item.max():.4f}")
        print(f"std: {item.std(dim=dim):.4f}")
        print(f"mean: {item.mean(dim=dim):.4f}")
    if print_tensor:
        print(item)
    if has_title:
        print("".join(["-"] * _b), "\n")


def set_seed(seed: int):
    """Set random seed for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if torch.mps.is_available():
        torch.mps.manual_seed(seed)
    if torch.xpu.is_available():
        torch.xpu.manual_seed_all(seed)


def count_parameters(model: Module) -> int:
    """Returns total number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def freeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
    """Freezes all model parameters except specified layers."""
    no_exceptions = not except_layers
    for name, param in model.named_parameters():
        if no_exceptions:
            param.requires_grad_(False)
        elif any(layer in name for layer in except_layers):
            param.requires_grad_(False)


def freeze_selected_weights(model: Module, target_layers: list[str]):
    """Freezes only parameters on specified layers."""
    for name, param in model.named_parameters():
        if any(layer in name for layer in target_layers):
            param.requires_grad_(False)


def unfreeze_all_except(model: Module, except_layers: Optional[list[str]] = None):
    """Unfreezes all model parameters except specified layers."""
    no_exceptions = not except_layers
    for name, param in model.named_parameters():
        if no_exceptions:
            param.requires_grad_(True)
        elif not any(layer in name for layer in except_layers):
            param.requires_grad_(True)


def unfreeze_selected_weights(model: Module, target_layers: list[str]):
    """Unfreezes only parameters on specified layers."""
    for name, param in model.named_parameters():
        if not any(layer in name for layer in target_layers):
            param.requires_grad_(True)


def clip_gradients(model: Module, max_norm: float = 1.0):
    """Applies gradient clipping."""
    return nn.utils.clip_grad_norm_(model.parameters(), max_norm)


def detach_hidden(hidden):
    """Detaches hidden states (for RNNs)."""
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    else:
        return tuple(detach_hidden(h) for h in hidden)


def tensor_summary(tensor: torch.Tensor) -> str:
    """Prints min/max/mean/std of a tensor for debugging."""
    return f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, min: {tensor.min():.4f}, max: {tensor.max():.4f}, mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"


def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encodes a tensor of labels."""
    return F.one_hot(labels, num_classes).float()


def safe_divide(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
    """Safe division for tensors (prevents divide-by-zero)."""
    return a / (b + eps)


def batch_pad(tensors: list[torch.Tensor], padding_value: float = 0.0) -> torch.Tensor:
    """Pads a list of tensors to the same shape (assumes 2D+ tensors)."""
    max_shape = [
        max(s[i] for s in [t.shape for t in tensors]) for i in range(tensors[0].dim())
    ]
    padded = []
    for t in tensors:
        pad_dims = [(0, m - s) for s, m in zip(t.shape, max_shape)]
        pad_flat = [p for pair in reversed(pad_dims) for p in pair]  # reverse for F.pad
        padded.append(F.pad(t, pad_flat, value=padding_value))
    return torch.stack(padded)


def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
    """Randomly samples values from tensor for preview."""
    flat = tensor.flatten()
    idx = torch.randperm(len(flat))[:num_samples]
    return flat[idx]


class TorchCacheUtils:
    cached_shortcuts: dict[str, Callable[[None], None]] = {}

    has_cuda: bool = torch.cuda.is_available()
    has_xpu: bool = torch.xpu.is_available()
    has_mps: bool = torch.mps.is_available()

    _ignore: list[str] = []

    def __init__(self):
        pass

    def _apply_clear(self, device: str):
        if device in self._ignore:
            gc.collect()
            return
        try:
            clear_fn = self.cached_shortcuts.get(
                device, getattr(torch, device).empty_cache
            )
            if device not in self.cached_shortcuts:
                self.cached_shortcuts.update({device: clear_fn})

        except Exception as e:
            print(e)
            self._ignore.append(device)

    def clear(self):
        gc.collect()
        if self.has_xpu:
            self._apply_clear("xpu")
        if self.has_cuda:
            self._apply_clear("cuda")
        if self.has_mps:
            self._apply_clear("mps")
        gc.collect()


_clear_cache_cls = TorchCacheUtils()


def clear_cache():
    _clear_cache_cls.clear()


@cache_wrapper
def default_device(idx: Optional[int] = None):
    try:
        if torch.cuda.is_available():
            return torch.device("cuda", idx)
        if torch.xpu.is_available():
            return torch.device("xpu", idx)
        if torch.mps.is_available():
            return torch.device("mps", idx)
        if hasattr(torch, "is_vulkan_available"):
            if getattr(torch, "is_vulkan_available")():
                return torch.device("vulkan", idx)
    except:
        pass
    finally:
        return torch.device(torch.zeros(1).device)


class Packing:
    """
    example:

    ```
    x_lengths = torch.tensor([5, 3, 6])
    x_padded = torch.randn(3, 6, 256)  # padded input [B, T, C]

    # 1. RNN expects packed input
    x_packed = Padding.pack_sequence(x_padded, x_lengths)
    output_packed, _ = rnn(x_packed)

    # 2. Recover padded for loss
    output = Padding.unpack_sequence(output_packed, total_length=x_padded.size(1))

    # 3. Mask for loss
    mask = torch.arange(x_padded.size(1))[None, :] < x_lengths[:, None]
    loss = (F.mse_loss(output, target, reduction="none") * mask.unsqueeze(-1)).sum() / mask.sum()
    ```
    """

    @staticmethod
    def pack_sequence(x: Tensor, lengths: Tensor):
        """
        Pack padded sequence for RNN/LSTM.
        Args:
            x (Tensor): Padded input [B, T, C]
            lengths (Tensor): Actual lengths [B]
        Returns:
            PackedSequence

        """
        return nn.utils.rnn.pack_padded_sequence(
            x,
            lengths.cpu().numpy(),
            batch_first=True,
            enforce_sorted=False,
        )

    @staticmethod
    def unpack_sequence(packed, total_length: int) -> Tensor:
        """Unpacks RNN PackedSequence to padded [B, T, C]."""
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed,
            batch_first=True,
            total_length=total_length,
        )
        return output


class Padding:

    @staticmethod
    def pad_to(x: Tensor, target_length: int, pad_value: float = 0.0) -> Tensor:
        """
        Pad input tensor along time axis (dim=1) to target length.
        Args:
            x (Tensor): Input tensor [B, T, C]
            target_length (int): Target time length
            pad_value (float): Fill value
        Returns:
            Padded tensor [B, target_length, C]
        """
        B, T, C = x.size()
        if T >= target_length:
            return x
        pad = x.new_full((B, target_length - T, C), pad_value)
        return torch.cat([x, pad], dim=1)

    @staticmethod
    def pad_sequence(
        inputs: Tensor,
        size: int,
        direction: Literal["left", "right"] = "left",
        pad_id: Union[int, float] = 0,
    ) -> Tensor:
        """
        Pads a single tensor to the specified size in 1D.
        Args:
            inputs (Tensor): Tensor of shape [T] or [B, T]
            size (int): Desired size along the last dimension
            direction (str): 'left' or 'right'
            pad_id (int): Value to pad with
        Returns:
            Padded tensor
        """
        total = size - inputs.shape[-1]
        if total < 1:
            return inputs
        pad_config = (total, 0) if direction == "left" else (0, total)
        return F.pad(inputs, pad_config, value=pad_id)

    @staticmethod
    def pad_batch_1d(
        batch: List[Tensor],
        pad_value: float = 0.0,
        pad_to_multiple: Optional[int] = None,
        direction: Literal["left", "right"] = "right",
    ) -> Tuple[Tensor, Tensor]:
        """
        Pad list of 1D tensors to same length with optional multiple alignment.
        Returns:
            Padded tensor [B, T], Lengths [B]
        """
        lengths = torch.tensor([t.size(0) for t in batch])
        max_len = lengths.max().item()

        if pad_to_multiple:
            max_len = (
                (max_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

        padded = []
        for t in batch:
            padded.append(Padding.pad_sequence(t, max_len, direction, pad_value))
        return torch.stack(padded), lengths

    @staticmethod
    def pad_batch_2d(
        batch: List[Tensor],
        pad_value: float = 0.0,
        pad_to_multiple: Optional[int] = None,
        direction: Literal["left", "right"] = "right",
    ) -> Tuple[Tensor, Tensor]:
        """
        Pad list of 2D tensors (e.g. [T, D]) to same T.
        Returns:
            Padded tensor [B, T, D], Lengths [B]
        """
        lengths = torch.tensor([t.size(0) for t in batch])
        feat_dim = batch[0].size(1)
        max_len = lengths.max().item()

        if pad_to_multiple:
            max_len = (
                (max_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

        padded = []
        for t in batch:
            pad_len = max_len - t.size(0)
            if direction == "left":
                pad_tensor = t.new_full((pad_len, feat_dim), pad_value)
                padded.append(torch.cat([pad_tensor, t], dim=0))
            else:
                pad_tensor = t.new_full((pad_len, feat_dim), pad_value)
                padded.append(torch.cat([t, pad_tensor], dim=0))
        return torch.stack(padded), lengths

    # --- Batching ---

    @staticmethod
    def pad_batch_1d(
        batch: List[Tensor],
        pad_value: float = 0.0,
        pad_to_multiple: Optional[int] = None,
        direction: Literal["left", "right"] = "right",
    ) -> Tuple[Tensor, Tensor]:
        """Pads list of 1D tensors → [B, T]"""
        lengths = torch.tensor([t.size(0) for t in batch])
        max_len = lengths.max().item()
        if pad_to_multiple:
            max_len = (
                (max_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

        padded = [Padding.pad_sequence(t, max_len, direction, pad_value) for t in batch]
        return torch.stack(padded), lengths

    @staticmethod
    def pad_batch_2d(
        batch: List[Tensor],
        pad_value: float = 0.0,
        pad_to_multiple: Optional[int] = None,
        direction: Literal["left", "right"] = "right",
    ) -> Tuple[Tensor, Tensor]:
        """Pads list of 2D tensors [T, D] → [B, T, D]"""
        lengths = torch.tensor([t.size(0) for t in batch])
        feat_dim = batch[0].size(1)
        max_len = lengths.max().item()
        if pad_to_multiple:
            max_len = (
                (max_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

        padded = []
        for t in batch:
            pad_len = max_len - t.size(0)
            pad_tensor = t.new_full((pad_len, feat_dim), pad_value)
            padded_tensor = (
                torch.cat([pad_tensor, t], dim=0)
                if direction == "left"
                else torch.cat([t, pad_tensor], dim=0)
            )
            padded.append(padded_tensor)
        return torch.stack(padded), lengths

    @staticmethod
    def pad_batch_nd(
        batch: List[Tensor],
        pad_value: float = 0.0,
        dim: int = 0,
        pad_to_multiple: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        General N-D padding along time axis (dim=0, usually).
        Handles shapes like:
            [T, C] → [B, T, C]
            [T, H, W] → [B, T, H, W]
        """
        lengths = torch.tensor([t.size(dim) for t in batch])
        max_len = lengths.max().item()
        if pad_to_multiple:
            max_len = (
                (max_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

        padded = []
        for t in batch:
            pad_len = max_len - t.size(dim)
            pad_shape = list(t.shape)
            pad_shape[dim] = pad_len
            pad_tensor = t.new_full(pad_shape, pad_value)
            padded_tensor = torch.cat([t, pad_tensor], dim=dim)
            padded.append(padded_tensor)

        return torch.stack(padded), lengths


class MaskUtils:

    @staticmethod
    def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
        """
        Apply a mask to a tensor, setting masked positions to `fill_value`.
        Args:
            x (Tensor): Input tensor of shape [..., T, D].
            mask (Tensor): Mask of shape [..., T] where True = masked.
            fill_value (Number): Value to fill masked positions with.
        Returns:
            Tensor: Masked tensor.
        """
        return x.masked_fill(mask.unsqueeze(-1), fill_value)

    @staticmethod
    def get_padding_mask(
        lengths: Optional[Tensor] = None,
        tokens: Optional[Tensor] = None,
        padding_id: int = 0,
    ) -> Tensor:
        """
        Generate a padding mask: 1 for real tokens, 0 for padding.
        Args:
            lengths (Tensor): Tensor of shape [B] with sequence lengths.
            tokens (Tensor): Tensor of shape [B, T] with token ids.
            padding_id (int): Padding token id (default=0).
        Returns:
            Tensor: Boolean mask of shape [B, T].
        """
        assert (
            tokens is not None or lengths is not None
        ), "Either tokens or lengths must be provided."

        if tokens is not None:
            return tokens != padding_id

        B = lengths.size(0)
        max_len = lengths.max().item()
        arange = torch.arange(max_len, device=lengths.device).unsqueeze(0).expand(B, -1)
        return arange < lengths.unsqueeze(1)

    @staticmethod
    def get_padding_mask_fps(lengths: Tensor) -> Tensor:
        """
        Legacy-style padding mask using 1-based comparison.
        """
        mask = (
            torch.arange(lengths.max(), device=lengths.device)
            .unsqueeze(0)
            .expand(lengths.shape[0], -1)
        )
        return (mask + 1) > lengths.unsqueeze(1)

    @staticmethod
    def get_causal_mask(
        size: Union[int, tuple[int, ...]],
        device: Optional[Union[str, torch.device]] = None,
    ) -> Tensor:
        """
        Generate a causal mask for self-attention.
        Args:
            size (int or tuple): Size (T) or (1, T, T)
        Returns:
            Tensor: [1, T, T] boolean causal mask
        """
        if isinstance(size, int):
            size = (1, size, size)
        return torch.tril(torch.ones(size, dtype=torch.bool, device=device))

    @staticmethod
    def combine_masks(pad_mask: Tensor, causal_mask: Tensor) -> Tensor:
        """
        Combine padding and causal masks.
        Args:
            pad_mask (Tensor): [B, T] padding mask
            causal_mask (Tensor): [1, T, T] causal mask
        Returns:
            Tensor: [B, T, T] combined mask
        """
        return (
            causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
        )


def masked_cross_entropy(
    logits: torch.Tensor,  # [B, T, V]
    targets: torch.Tensor,  # [B, T]
    lengths: torch.Tensor,  # [B]
    reduction: str = "mean",
) -> torch.Tensor:
    """
    CrossEntropyLoss with masking for variable-length sequences.
    - logits: unnormalized scores [B, T, V]
    - targets: ground truth indices [B, T]
    - lengths: actual sequence lengths [B]
    """
    B, T, V = logits.size()
    logits = logits.view(-1, V)
    targets = targets.view(-1)

    # Create mask
    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
    mask = mask.reshape(-1)

    # Apply CE only where mask == True
    loss = F.cross_entropy(
        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
    )
    if reduction == "none":
        return loss
    return loss


class NoiseScheduler(Module):
    def __init__(self, timesteps: int = 512):
        super().__init__()

        betas = torch.linspace(1e-4, 0.02, timesteps)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
        )

        self.timesteps = timesteps
        self.default_noise = math.sqrt(1.25)

    def get_random_noise(
        self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
    ) -> float:
        if seed > 0:
            random.seed(seed)
        return random.uniform(*min_max)

    def set_noise(
        self,
        seed: int = 0,
        min_max: Tuple[float, float] = (-3, 3),
        default: bool = False,
    ):
        self.default_noise = (
            math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
        )

    def forward(
        self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
    ) -> Tensor:
        if t < 0 or t >= self.timesteps:
            raise ValueError(
                f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
            )

        if noise is None:
            noise = self.default_noise

        if isinstance(noise, (float, int)):
            noise = torch.randn_like(x_0) * noise

        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
        return alpha_term + noise_term
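For orientation, here is a minimal usage sketch (not part of the package) that combines the padding, masking, loss, and noise-scheduling helpers defined above. It assumes `lt_tensor` and its `lt_utils` dependency are installed, that the names re-exported by `_torch_commons` (`Tensor`, `Module`, `F`, etc.) are the usual `torch` types, and that a recent `torch` build is available; all shapes and hyperparameters below are illustrative.

```python
import torch
from lt_tensor.misc_utils import (
    Padding,
    MaskUtils,
    masked_cross_entropy,
    NoiseScheduler,
)

# Three variable-length feature sequences, each [T, D] with D = 80.
batch = [torch.randn(t, 80) for t in (50, 37, 64)]

# Pad to a common length: feats is [B, T, D], lengths is [B].
feats, lengths = Padding.pad_batch_2d(batch)

# Boolean padding mask [B, T]: True for real frames, False for padding.
pad_mask = MaskUtils.get_padding_mask(lengths=lengths)

# Length-aware cross-entropy over dummy per-frame logits/targets.
logits = torch.randn(feats.size(0), feats.size(1), 10)
targets = torch.randint(0, 10, (feats.size(0), feats.size(1)))
loss = masked_cross_entropy(logits, targets, lengths)

# Forward-diffusion-style corruption of the padded features at step t=100.
scheduler = NoiseScheduler(timesteps=512)
noisy = scheduler(feats, t=100)
```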
lt_tensor/model_zoo/basic.py
ADDED
@@ -0,0 +1,65 @@
from .._torch_commons import *
from .._basics import Model
from ..transform import get_sinusoidal_embedding


class FeedForward(Model):
    def __init__(
        self,
        d_model: int,
        ff_dim: int,
        dropout: float = 0.01,
        activation: nn.Module = nn.LeakyReLU(0.1),
        normalizer: nn.Module = nn.Identity(),
    ):
        """Creates a Feed-Forward Layer, with the chosen activation function and the normalizer."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
            normalizer,
        )

    def forward(self, x: Tensor):
        return self.net(x)


class MLP(Model):
    def __init__(
        self,
        d_model: int,
        ff_dim: int,
        n_classes: int,
        dropout: float = 0.01,
        activation: nn.Module = nn.LeakyReLU(0.1),
        normalizer: nn.Module = nn.Identity(),
    ):
        """Creates a MLP block, with the chosen activation function and the normalizer."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(ff_dim, n_classes),
            normalizer,
        )

    def forward(self, x: Tensor):
        return self.net(x)


class TimestepEmbedder(nn.Module):
    def __init__(self, dim_emb: int, proj_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim_emb, proj_dim),
            nn.SiLU(),
            nn.Linear(proj_dim, proj_dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: [B] (long)
        emb = get_sinusoidal_embedding(t, self.mlp[0].in_features)  # [B, dim_emb]
        return self.mlp(emb)  # [B, proj_dim]
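The blocks in `model_zoo/basic.py` are thin wrappers around `nn.Sequential`. A usage sketch (not part of the package) follows, assuming the hunk above is indeed `lt_tensor/model_zoo/basic.py` as listed, that `Model` from `_basics` behaves like a standard `nn.Module`, and that `get_sinusoidal_embedding` returns a `[B, dim_emb]` tensor as the inline comments indicate:

```python
import torch
from lt_tensor.model_zoo.basic import FeedForward, MLP, TimestepEmbedder

# Position-wise feed-forward over a [B, T, d_model] sequence.
ff = FeedForward(d_model=256, ff_dim=1024, dropout=0.1)
y = ff(torch.randn(2, 50, 256))      # [2, 50, 256]

# Classification head: [B, d_model] -> [B, n_classes].
head = MLP(d_model=256, ff_dim=512, n_classes=10)
logits = head(torch.randn(2, 256))   # [2, 10]

# Timestep embedding for diffusion models: integer steps [B] -> [B, proj_dim].
embedder = TimestepEmbedder(dim_emb=128, proj_dim=256)
t = torch.randint(0, 512, (2,))
t_emb = embedder(t)                  # [2, 256]
```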