lt-tensor 0.0.1a4__py3-none-any.whl → 0.0.1a7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- lt_tensor/__init__.py +9 -1
- lt_tensor/datasets/audio.py +94 -95
- lt_tensor/losses.py +145 -0
- lt_tensor/math_ops.py +7 -0
- lt_tensor/misc_utils.py +10 -96
- lt_tensor/model_base.py +105 -6
- lt_tensor/model_zoo/disc.py +14 -14
- lt_tensor/model_zoo/istft.py +41 -0
- lt_tensor/noise_tools.py +368 -0
- lt_tensor/processors/__init__.py +3 -0
- lt_tensor/processors/audio.py +193 -0
- lt_tensor/transform.py +190 -30
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a7.dist-info}/METADATA +2 -2
- lt_tensor-0.0.1a7.dist-info/RECORD +28 -0
- lt_tensor-0.0.1a4.dist-info/RECORD +0 -24
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a7.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a7.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a4.dist-info → lt_tensor-0.0.1a7.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
@@ -1,13 +1,17 @@
 __version__ = "0.0.1a"

 from . import (
+    lr_schedulers,
     model_zoo,
     model_base,
     math_ops,
     misc_utils,
     monotonic_align,
     transform,
-    lr_schedulers,
+    noise_tools,
+    losses,
+    processors,
+    datasets,
 )

 __all__ = [
@@ -18,4 +22,8 @@ __all__ = [
     "monotonic_align",
     "transform",
     "lr_schedulers",
+    "noise_tools",
+    "losses",
+    "processors",
+    "datasets",
 ]
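With this change the four new submodules are wired into the package root alongside the existing ones, so (with 0.0.1a7 installed) all of the following imports resolve:

    from lt_tensor import noise_tools, losses, processors, datasets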
lt_tensor/datasets/audio.py
CHANGED
@@ -1,110 +1,109 @@
-__all__ = ["…
+__all__ = ["WaveMelDatasets"]
 from ..torch_commons import *
-import torchaudio
 from lt_utils.common import *
-import …
-from …
-from …
-…
+import random
+from torch.utils.data import Dataset, DataLoader, Sampler
+from ..processors import AudioProcessor
+import torch.nn.functional as FT
+from ..misc_utils import log_tensor


-class …
+class WaveMelDataset(Dataset):
+    """Untested!"""
+
+    data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []

     def __init__(
         self,
-        …
-        hop_length: int = 256,
-        f_min: float = 0,
-        f_max: float | None = None,
-        n_iter: int = 32,
-        center: bool = True,
-        mel_scale: Literal["htk", "slaney"] = "htk",
-        inv_n_fft: int = 16,
-        inv_hop: int = 4,
-        std: int = 4,
-        mean: int = -4,
+        audio_processor: AudioProcessor,
+        path: PathLike,
+        limit_files: Optional[int] = None,
+        max_frame_length: Optional[int] = None,
     ):
-        …
-        self.n_fft = n_fft
-        self.n_stft = n_fft // 2 + 1
-        self.f_min = f_min
-        self.f_max = f_max
-        self.n_iter = n_iter
-        self.hop_length = hop_length
-        self.sample_rate = sample_rate
-        self.mel_spec = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_mels=n_mels,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            center=center,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.mel_rscale = torchaudio.transforms.InverseMelScale(
-            n_stft=self.n_stft,
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=f_min,
-            f_max=f_max,
-            mel_scale=mel_scale,
-        )
-        self.giffin_lim = torchaudio.transforms.GriffinLim(
-            n_fft=n_fft,
-            n_iter=n_iter,
-            win_length=win_length,
-            hop_length=hop_length,
-        )
-        self._inverse_transform = lambda x, y: inverse_transform(
-            x, y, inv_n_fft, inv_hop, inv_n_fft
+        super().__init__()
+        assert max_frame_length is None or max_frame_length >= (
+            (audio_processor.n_fft // 2) + 1
         )
+        self.post_n_fft = (audio_processor.n_fft // 2) + 1
+        self.ap = audio_processor
+        self.files = self.ap.find_audios(path)
+        if limit_files:
+            random.shuffle(self.files)
+            self.files = self.files[:limit_files]
+        self.data = []

-        …
+        for file in self.files:
+            results = self.load_data(file, max_frame_length)
+            self.data.extend(results)

-    def …
-        …
-        wave: Tensor,
-    ) -> Tensor:
-        """Returns: [B, M, ML]"""
-        mel_tensor = self.mel_spec(wave)  # [M, ML]
-        mel_tensor = (mel_tensor - self.mean) / self.std
-        return mel_tensor  # [B, M, ML]
+    def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
+        return {"mel": audio_mel, "raw": audio_raw, "file": file}

-    def …
-        …
+    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
+        initial_audio = self.ap.load_audio(file)
+        if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
+            audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+            return [self._add_dict(initial_audio, audio_mel, file)]
+        results = []
+        for fragment in torch.split(
+            initial_audio, split_size_or_sections=audio_frames_limit, dim=-1
+        ):
+            if fragment.shape[-1] < self.post_n_fft:
+                # sometimes the tensor will be too small to be able to pass on mel
+                continue
+            audio_mel = self.ap.compute_mel(fragment, add_base=True)
+            results.append(self._add_dict(fragment, audio_mel, file))
+        return results

-    def …
+    def get_data_loader(
         self,
-        …
+        batch_size: int = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Optional[Union[Sampler, Iterable]] = None,
+        batch_sampler: Optional[Union[Sampler[list], Iterable[list]]] = None,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+    ):
+        return DataLoader(
+            self,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            timeout=timeout,
+            collate_fn=self.collate_fn,
         )
+
+    @staticmethod
+    def collate_fn(batch: Sequence[Dict[str, Tensor]]):
+        mels = []
+        audios = []
+        files = []
+        for x in batch:
+            mels.append(x["mel"])
+            audios.append(x["raw"])
+            files.append(x["file"])
+        # Find max time in mel (dim -1), and max audio length
+        max_mel_len = max([m.shape[-1] for m in mels])
+        max_audio_len = max([a.shape[-1] for a in audios])
+
+        padded_mels = torch.stack(
+            [FT.pad(m, (0, max_mel_len - m.shape[-1])) for m in mels]
+        )  # shape: [B, 80, T_max]
+
+        padded_audios = torch.stack(
+            [FT.pad(a, (0, max_audio_len - a.shape[-1])) for a in audios]
+        )  # shape: [B, L_max]
+
+        return padded_mels, padded_audios, files
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]
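For orientation, a minimal sketch of how the reworked dataset appears intended to be used, inferred only from the diff above. The AudioProcessor() call with default arguments is an assumption (lt_tensor/processors/audio.py is new in this release but its contents are not shown here), the class is marked "Untested!" by its author, and note that __all__ exports the name "WaveMelDatasets" while the class is actually WaveMelDataset:

    import torch
    from lt_tensor.processors import AudioProcessor
    from lt_tensor.datasets.audio import WaveMelDataset

    ap = AudioProcessor()  # assumption: constructible with defaults (file not shown in this diff)
    dataset = WaveMelDataset(
        audio_processor=ap,
        path="wavs/",            # hypothetical directory; ap.find_audios(path) collects the files
        limit_files=100,         # optional: shuffles the file list and keeps the first 100
        max_frame_length=16384,  # longer clips are split into fragments of at most this many samples
    )
    loader = dataset.get_data_loader(batch_size=8, shuffle=True)
    for padded_mels, padded_audios, files in loader:
        # collate_fn right-pads each batch: mels -> [B, n_mels, T_max], audio -> [B, L_max]
        print(padded_mels.shape, padded_audios.shape, len(files))
        break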
lt_tensor/losses.py
ADDED
@@ -0,0 +1,145 @@
+__all__ = ["masked_cross_entropy"]
+import math
+import random
+from .torch_commons import *
+from lt_utils.common import *
+import torch.nn.functional as F
+
+
+def masked_cross_entropy(
+    logits: torch.Tensor,  # [B, T, V]
+    targets: torch.Tensor,  # [B, T]
+    lengths: torch.Tensor,  # [B]
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    CrossEntropyLoss with masking for variable-length sequences.
+    - logits: unnormalized scores [B, T, V]
+    - targets: ground truth indices [B, T]
+    - lengths: actual sequence lengths [B]
+    """
+    B, T, V = logits.size()
+    logits = logits.view(-1, V)
+    targets = targets.view(-1)
+
+    # Create mask
+    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+    mask = mask.reshape(-1)
+
+    # Apply CE only where mask == True
+    loss = F.cross_entropy(
+        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+    )
+    if reduction == "none":
+        return loss
+    return loss
+
+
+def diff_loss(pred_noise, true_noise, mask=None):
+    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+    if mask is not None:
+        return F.mse_loss(pred_noise * mask, true_noise * mask)
+    return F.mse_loss(pred_noise, true_noise)
+
+
+def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+    """Combines L1 and L2"""
+    l1 = F.l1_loss(pred_noise, true_noise)
+    l2 = F.mse_loss(pred_noise, true_noise)
+    return alpha * l1 + (1 - alpha) * l2
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
+
+
+def gan_g_loss(fake_preds, use_lsgan=True):
+    loss = 0
+    for fake in fake_preds:
+        if use_lsgan:
+            loss += F.mse_loss(fake, torch.ones_like(fake))
+        else:
+            loss += -torch.mean(torch.log(fake + 1e-7))
+    return loss
+
+
+def feature_matching_loss(real_feats, fake_feats):
+    """real_feats and fake_feats are lists of intermediate features"""
+    loss = 0
+    for real_layers, fake_layers in zip(real_feats, fake_feats):
+        for r, f in zip(real_layers, fake_layers):
+            loss += F.l1_loss(f, r.detach())
+    return loss
+
+
+def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
+    loss = 0.0
+    for dr, dg in zip(real_fmaps, fake_fmaps):  # Each (layer list from a discriminator)
+        for r_feat, g_feat in zip(dr, dg):
+            loss += F.l1_loss(r_feat, g_feat)
+    return loss * weight
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    loss = 0.0
+    r_losses = []
+    g_losses = []
+
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = F.mse_loss(dr, torch.ones_like(dr))
+        g_loss = F.mse_loss(dg, torch.zeros_like(dg))
+        loss += r_loss + g_loss
+        r_losses.append(r_loss)
+        g_losses.append(g_loss)
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(fake_outputs):
+    total = 0.0
+    g_losses = []
+    for out in fake_outputs:
+        loss = F.mse_loss(out, torch.ones_like(out))
+        g_losses.append(loss)
+        total += loss
+    return total, g_losses
+
+
+def multi_resolution_stft_loss(y, y_hat, fft_sizes=[512, 1024, 2048]):
+    loss = 0
+    for fft_size in fft_sizes:
+        hop = fft_size // 4
+        win = fft_size
+        y_stft = torch.stft(
+            y, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
+        )
+        y_hat_stft = torch.stft(
+            y_hat, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
+        )
+
+        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
+    return loss
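Two quirks in the new file are worth flagging: gan_d_loss is defined twice with identical bodies (the second definition silently shadows the first at import time), and in masked_cross_entropy the final reduction check is redundant, since both branches return the same loss. A minimal sketch of the one exported helper, using the shapes its docstring expects:

    import torch
    from lt_tensor.losses import masked_cross_entropy

    B, T, V = 4, 12, 32
    logits = torch.randn(B, T, V)          # unnormalized scores [B, T, V]
    targets = torch.randint(0, V, (B, T))  # ground-truth indices [B, T]
    lengths = torch.tensor([12, 9, 7, 3])  # valid steps per sequence [B]

    # Positions with t >= lengths[b] are masked out of the average.
    loss = masked_cross_entropy(logits, targets, lengths)
    print(loss.item())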
lt_tensor/math_ops.py
CHANGED
@@ -8,6 +8,7 @@ __all__ = [
     "dot_product",
     "normalize_tensor",
     "log_magnitude",
+    "shift_time",
     "phase",
 ]

@@ -50,6 +51,11 @@ def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
     return torch.roll(x, shifts=1, dims=dim)


+def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
+    """Shifts tensor along time axis (last dim)."""
+    return torch.roll(x, shifts=shift, dims=-1)
+
+
 def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
     """Computes dot product along the specified dimension."""
     return torch.sum(x * y, dim=dim)
@@ -69,3 +75,4 @@ def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
 def phase(stft_complex: Tensor) -> Tensor:
     """Returns phase from complex STFT."""
     return torch.angle(stft_complex)
+
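The new shift_time is a thin wrapper around torch.roll, so samples wrap around rather than being zero-padded; a small illustration:

    import torch
    from lt_tensor.math_ops import shift_time

    x = torch.tensor([[1, 2, 3, 4]])
    print(shift_time(x, 1))   # tensor([[4, 1, 2, 3]]): the last sample wraps to the front
    print(shift_time(x, -1))  # tensor([[2, 3, 4, 1]]): the first sample wraps to the end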
lt_tensor/misc_utils.py
CHANGED
@@ -8,7 +8,6 @@ __all__ = [
     "unfreeze_selected_weights",
     "clip_gradients",
     "detach_hidden",
-    "tensor_summary",
     "one_hot",
     "safe_divide",
     "batch_pad",
@@ -18,22 +17,20 @@ __all__ = [
     "default_device",
     "Packing",
     "Padding",
-    "MaskUtils",
-    "masked_cross_entropy",
-    "NoiseScheduler",
+    "Masking",
 ]

 import gc
+import sys
 import random
 import numpy as np
 from lt_utils.type_utils import is_str
 from .torch_commons import *
-from lt_utils.misc_utils import …
-from lt_utils.file_ops import load_json, load_yaml, save_json, save_yaml
-import math
+from lt_utils.misc_utils import cache_wrapper
 from lt_utils.common import *
 import torch.nn.functional as F

+
 def log_tensor(
     item: Union[Tensor, np.ndarray],
     title: Optional[str] = None,
@@ -64,10 +61,13 @@
             print(f"mean: {item.mean(dim=dim):.4f}")
         except:
             pass
-    …
+    if print_tensor:
+        print(item)
     if has_title:
         print("".join(["-"] * _b), "\n")
+    else:
+        print("\n")
+    sys.stdout.flush()


 def set_seed(seed: int):
@@ -136,11 +136,6 @@ def detach_hidden(hidden):
     return tuple(detach_hidden(h) for h in hidden)


-def tensor_summary(tensor: torch.Tensor) -> str:
-    """Prints min/max/mean/std of a tensor for debugging."""
-    return f"Shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}, min: {tensor.min():.4f}, max: {tensor.max():.4f}, mean: {tensor.mean():.4f}, std: {tensor.std():.4f}"
-
-
 def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
     """One-hot encodes a tensor of labels."""
     return F.one_hot(labels, num_classes).float()
@@ -463,7 +458,7 @@ class Padding:
         return torch.stack(padded), lengths


-class MaskUtils:
+class Masking:

     @staticmethod
     def apply_mask(x: Tensor, mask: Tensor, fill_value: Number = 0) -> Tensor:
@@ -546,84 +541,3 @@ class MaskUtils:
         return (
             causal_mask & pad_mask.unsqueeze(1).expand(-1, pad_mask.size(1), -1).bool()
         )
-
-
-def masked_cross_entropy(
-    logits: torch.Tensor,  # [B, T, V]
-    targets: torch.Tensor,  # [B, T]
-    lengths: torch.Tensor,  # [B]
-    reduction: str = "mean",
-) -> torch.Tensor:
-    """
-    CrossEntropyLoss with masking for variable-length sequences.
-    - logits: unnormalized scores [B, T, V]
-    - targets: ground truth indices [B, T]
-    - lengths: actual sequence lengths [B]
-    """
-    B, T, V = logits.size()
-    logits = logits.view(-1, V)
-    targets = targets.view(-1)
-
-    # Create mask
-    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
-    mask = mask.reshape(-1)
-
-    # Apply CE only where mask == True
-    loss = F.cross_entropy(
-        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
-    )
-    if reduction == "none":
-        return loss
-    return loss
-
-
-class NoiseScheduler(nn.Module):
-    def __init__(self, timesteps: int = 512):
-        super().__init__()
-
-        betas = torch.linspace(1e-4, 0.02, timesteps)
-        alphas = 1.0 - betas
-        alpha_cumprod = torch.cumprod(alphas, dim=0)
-
-        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
-        self.register_buffer(
-            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
-        )
-
-        self.timesteps = timesteps
-        self.default_noise = math.sqrt(1.25)
-
-    def get_random_noise(
-        self, min_max: Tuple[float, float] = (-3, 3), seed: int = 0
-    ) -> float:
-        if seed > 0:
-            random.seed(seed)
-        return random.uniform(*min_max)
-
-    def set_noise(
-        self,
-        seed: int = 0,
-        min_max: Tuple[float, float] = (-3, 3),
-        default: bool = False,
-    ):
-        self.default_noise = (
-            math.sqrt(1.25) if default else self.get_random_noise(min_max, seed)
-        )
-
-    def forward(
-        self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
-    ) -> Tensor:
-        if t < 0 or t >= self.timesteps:
-            raise ValueError(
-                f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
-            )
-
-        if noise is None:
-            noise = self.default_noise
-
-        if isinstance(noise, (float, int)):
-            noise = torch.randn_like(x_0) * noise
-
-        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
-        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
-        return alpha_term + noise_term