lt-tensor 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/losses.py +169 -47
- lt_tensor/lr_schedulers.py +147 -21
- lt_tensor/misc_utils.py +35 -42
- lt_tensor/model_zoo/activations/__init__.py +3 -0
- lt_tensor/model_zoo/activations/alias_free/__init__.py +3 -0
- lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/act.py +8 -6
- lt_tensor/model_zoo/activations/snake/__init__.py +41 -43
- lt_tensor/model_zoo/audio_models/__init__.py +2 -2
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +243 -0
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +16 -347
- lt_tensor/model_zoo/audio_models/istft/__init__.py +14 -349
- lt_tensor/model_zoo/audio_models/resblocks.py +248 -0
- lt_tensor/model_zoo/convs.py +21 -32
- lt_tensor/model_zoo/losses/discriminators.py +143 -230
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/METADATA +1 -1
- lt_tensor-0.0.1a35.dist-info/RECORD +40 -0
- lt_tensor/model_zoo/activations/alias_free_torch/__init__.py +0 -1
- lt_tensor-0.0.1a33.dist-info/RECORD +0 -37
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/filter.py +0 -0
- /lt_tensor/model_zoo/activations/{alias_free_torch → alias_free}/resample.py +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a33.dist-info → lt_tensor-0.0.1a35.dist-info}/top_level.txt +0 -0
lt_tensor/__init__.py
CHANGED
lt_tensor/losses.py
CHANGED
@@ -6,19 +6,24 @@ __all__ = [
|
|
6
6
|
"hybrid_loss",
|
7
7
|
"diff_loss",
|
8
8
|
"cosine_loss",
|
9
|
-
"gan_loss",
|
10
9
|
"ft_n_loss",
|
10
|
+
"MultiMelScaleLoss",
|
11
11
|
]
|
12
12
|
import math
|
13
13
|
import random
|
14
14
|
from lt_tensor.torch_commons import *
|
15
15
|
from lt_utils.common import *
|
16
16
|
import torch.nn.functional as F
|
17
|
+
from lt_tensor.model_base import Model
|
18
|
+
from lt_tensor.processors import AudioProcessor, AudioProcessorConfig
|
19
|
+
from lt_tensor.math_ops import normalize_minmax, normalize_unit_norm, normalize_zscore
|
20
|
+
|
17
21
|
|
18
22
|
def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
    """Mean square root of the absolute error, optionally offset by ``weight``.

    Computes ``mean(sqrt(|output - target|))`` or, when ``weight`` is given,
    ``mean(sqrt(|output - target| + weight))`` (``weight`` is added, with
    normal tensor broadcasting, before the square root).

    Elements whose radicand is exactly zero contribute 0 to the loss and get a
    zero gradient. A plain ``x ** 0.5`` would back-propagate ``0 * inf = nan``
    there, e.g. as soon as ``output == target`` for any element.

    Args:
        output: Predicted tensor.
        target: Reference tensor, broadcastable against ``output``.
        weight: Optional tensor added to the absolute error before the root.

    Returns:
        A scalar tensor with the mean of the element-wise square roots.
    """
    base = torch.abs(output - target)
    if weight is not None:
        base = base + weight
    is_zero = base == 0
    # Feed 1 instead of 0 into the sqrt so its backward stays finite; the outer
    # where() discards that branch's value (and gradient) for those elements.
    safe = torch.where(is_zero, torch.ones_like(base), base)
    root = torch.where(is_zero, torch.zeros_like(base), safe**0.5)
    return torch.mean(root)
|
26
|
+
|
22
27
|
|
23
28
|
def adaptive_l1_loss(
|
24
29
|
inp: Tensor,
|
@@ -58,50 +63,6 @@ def cosine_loss(inp, tgt):
|
|
58
63
|
return 1 - cos.mean() # Lower is better
|
59
64
|
|
60
65
|
|
61
|
-
class GanLosses:
|
62
|
-
@staticmethod
|
63
|
-
def get_loss(
|
64
|
-
pred: Tensor,
|
65
|
-
target_is_real: bool,
|
66
|
-
loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
|
67
|
-
) -> Tensor:
|
68
|
-
if loss_type == "bce": # Standard GAN
|
69
|
-
target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
|
70
|
-
return F.binary_cross_entropy_with_logits(pred, target)
|
71
|
-
|
72
|
-
elif loss_type == "mse": # LSGAN
|
73
|
-
target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
|
74
|
-
return F.mse_loss(torch.sigmoid(pred), target)
|
75
|
-
|
76
|
-
elif loss_type == "hinge":
|
77
|
-
if target_is_real:
|
78
|
-
return torch.mean(F.relu(1.0 - pred))
|
79
|
-
else:
|
80
|
-
return torch.mean(F.relu(1.0 + pred))
|
81
|
-
|
82
|
-
elif loss_type == "wasserstein":
|
83
|
-
return -pred.mean() if target_is_real else pred.mean()
|
84
|
-
|
85
|
-
else:
|
86
|
-
raise ValueError(f"Unknown loss_type: {loss_type}")
|
87
|
-
|
88
|
-
@staticmethod
|
89
|
-
def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
|
90
|
-
return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
|
91
|
-
|
92
|
-
@staticmethod
|
93
|
-
def discriminator_loss(
|
94
|
-
real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
|
95
|
-
) -> Tensor:
|
96
|
-
real_loss = GanLosses.get_loss(
|
97
|
-
real_pred, target_is_real=True, loss_type=loss_type
|
98
|
-
)
|
99
|
-
fake_loss = GanLosses.get_loss(
|
100
|
-
fake_pred.detach(), target_is_real=False, loss_type=loss_type
|
101
|
-
)
|
102
|
-
return (real_loss + fake_loss) * 0.5
|
103
|
-
|
104
|
-
|
105
66
|
def masked_cross_entropy(
|
106
67
|
logits: torch.Tensor, # [B, T, V]
|
107
68
|
targets: torch.Tensor, # [B, T]
|
@@ -157,3 +118,164 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
|
|
157
118
|
torch.log(1 - fake + 1e-7)
|
158
119
|
)
|
159
120
|
return loss
|
121
|
+
|
122
|
+
|
123
|
+
class MultiMelScaleLoss(Model):
    """Multi-resolution mel-spectrogram loss with optional pitch and RMS terms.

    One ``AudioProcessor`` is built per scale from the zipped per-scale
    settings (``n_mels``, ``window_lengths``, ``n_ffts``, ``hops``, ``f_min``,
    ``f_max``). ``forward`` runs the input and target waveforms through every
    processor and accumulates, for each scale:

    * ``lambda_mel * loss_mel_fn(mel(input), mel(target))``
    * ``lambda_pitch * loss_pitch_fn(norm_pitch_fn(pitch(input)), ...)`` if
      ``use_pitch_loss``
    * ``lambda_rms * loss_rms_fn(norm_rms_fn(rms(input)), ...)`` if
      ``use_rms_loss``

    The sum over all scales is multiplied by ``weight``.

    Args:
        sample_rate: Sample rate passed to every ``AudioProcessorConfig``.
        n_mels, window_lengths, n_ffts, hops: Per-scale settings. ``None``
            selects the built-in 7-scale defaults.
        f_min, f_max: Per-scale frequency bounds. ``None`` means ``0`` /
            ``None`` for every scale.
        loss_mel_fn, loss_pitch_fn, loss_rms_fn: Criteria ``(pred, target)
            -> Tensor``. ``None`` creates a fresh ``nn.L1Loss`` for this
            instance.
        center, power, normalized, pad_mode, onesided, std, mean: Forwarded
            unchanged to every ``AudioProcessorConfig``.
        use_istft_norm: Pass both waveforms through
            ``AudioProcessor.istft_norm`` (with ``length`` set to the target
            length) before feature extraction. When False, input and target
            must have the same last-dimension length.
        use_pitch_loss, use_rms_loss: Enable the pitch / RMS terms.
        norm_pitch_fn, norm_rms_fn: Normalizers applied to the pitch / RMS
            features before comparison.
        lambda_mel, lambda_rms, lambda_pitch: Per-term weights.
        weight: Global multiplier applied to the total loss.

    Raises:
        AssertionError: If the per-scale settings are empty or differ in length.
    """

    def __init__(
        self,
        sample_rate: int,
        n_mels: Optional[List[int]] = None,
        window_lengths: Optional[List[int]] = None,
        n_ffts: Optional[List[int]] = None,
        hops: Optional[List[int]] = None,
        f_min: Optional[List[float]] = None,
        f_max: Optional[List[Optional[float]]] = None,
        loss_mel_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
        loss_pitch_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
        loss_rms_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
        center: bool = True,
        power: float = 1.0,
        normalized: bool = False,
        pad_mode: str = "reflect",
        onesided: Optional[bool] = None,
        std: int = 4,
        mean: int = -4,
        use_istft_norm: bool = True,
        use_pitch_loss: bool = True,
        use_rms_loss: bool = True,
        norm_pitch_fn: Callable[[Tensor], Tensor] = normalize_minmax,
        norm_rms_fn: Callable[[Tensor], Tensor] = normalize_zscore,
        lambda_mel: float = 1.0,
        lambda_rms: float = 1.0,
        lambda_pitch: float = 1.0,
        weight: float = 1.0,
    ):
        super().__init__()
        # Defaults are resolved here rather than in the signature: list
        # literals and nn.L1Loss() instances in a signature are evaluated once
        # and would be shared by every instance built with defaults.
        if n_mels is None:
            n_mels = [5, 10, 20, 40, 80, 160, 320]
        if window_lengths is None:
            window_lengths = [32, 64, 128, 256, 512, 1024, 2048]
        if n_ffts is None:
            n_ffts = [32, 64, 128, 256, 512, 1024, 2048]
        if hops is None:
            hops = [8, 16, 32, 64, 128, 256, 512]
        # f_min / f_max defaults do not depend on the scale, so size them to
        # however many scales were requested.
        if f_min is None:
            f_min = [0] * len(n_mels)
        if f_max is None:
            f_max = [None] * len(n_mels)

        self.loss_mel_fn = nn.L1Loss() if loss_mel_fn is None else loss_mel_fn
        self.loss_pitch_fn = nn.L1Loss() if loss_pitch_fn is None else loss_pitch_fn
        self.loss_rms_fn = nn.L1Loss() if loss_rms_fn is None else loss_rms_fn
        self.lambda_mel = lambda_mel
        self.weight = weight
        self.use_istft_norm = use_istft_norm
        self.use_pitch_loss = use_pitch_loss
        self.use_rms_loss = use_rms_loss
        self.lambda_pitch = lambda_pitch
        self.lambda_rms = lambda_rms

        self.norm_pitch_fn = norm_pitch_fn
        # Attribute name kept as ``norm_rms`` for backward compatibility.
        self.norm_rms = norm_rms_fn

        self._setup_mels(
            sample_rate,
            n_mels,
            window_lengths,
            n_ffts,
            hops,
            f_min,
            f_max,
            center,
            power,
            normalized,
            pad_mode,
            onesided,
            std,
            mean,
        )

    def _setup_mels(
        self,
        sample_rate: int,
        n_mels: List[int],
        window_lengths: List[int],
        n_ffts: List[int],
        hops: List[int],
        f_min: List[float],
        f_max: List[Optional[float]],
        center: bool,
        power: float,
        normalized: bool,
        pad_mode: str,
        onesided: Optional[bool],
        std: int,
        mean: int,
    ):
        """Build one AudioProcessor per scale into ``self.mel_spectrograms``.

        Raises:
            AssertionError: If the per-scale lists are empty or differ in length.
        """
        n_scales = len(n_mels)
        # An empty configuration would make forward() return a bare float.
        assert n_scales > 0 and all(
            len(values) == n_scales
            for values in (window_lengths, n_ffts, hops, f_min, f_max)
        ), (
            "n_mels, window_lengths, n_ffts, hops, f_min and f_max must be "
            "non-empty and have the same length"
        )
        shared_kwargs = dict(
            sample_rate=sample_rate,
            center=center,
            onesided=onesided,
            normalized=normalized,
            power=power,
            pad_mode=pad_mode,
            std=std,
            mean=mean,
        )
        self.mel_spectrograms: List[AudioProcessor] = nn.ModuleList(
            [
                AudioProcessor(
                    AudioProcessorConfig(
                        **shared_kwargs,
                        n_mels=mel,
                        n_fft=n_fft,
                        win_length=win,
                        hop_length=hop,
                        f_min=fmin,
                        f_max=fmax,
                    )
                )
                for mel, win, n_fft, hop, fmin, fmax in zip(
                    n_mels, window_lengths, n_ffts, hops, f_min, f_max
                )
            ]
        )

    def forward(
        self, input_wave: torch.Tensor, target_wave: torch.Tensor
    ) -> torch.Tensor:
        """Return the weighted multi-scale loss between two waveforms.

        ``target_wave`` is moved to ``input_wave``'s device. Unless
        ``use_istft_norm`` is enabled, both must have the same last-dimension
        length (asserted).
        """
        assert self.use_istft_norm or input_wave.shape[-1] == target_wave.shape[-1]
        target_wave = target_wave.to(input_wave.device)
        target_len = target_wave.shape[-1]
        losses = 0.0
        for proc in self.mel_spectrograms:
            # Optional ISTFT-based normalization, trimmed/padded to target length.
            if self.use_istft_norm:
                input_proc = proc.istft_norm(input_wave, length=target_len)
                target_proc = proc.istft_norm(target_wave, length=target_len)
            else:
                input_proc, target_proc = input_wave, target_wave

            x_mels = proc(input_proc)
            y_mels = proc(target_proc)
            losses += self.loss_mel_fn(x_mels.squeeze(), y_mels.squeeze()) * self.lambda_mel

            # pitch/f0 loss
            if self.use_pitch_loss:
                x_pitch = self.norm_pitch_fn(proc.compute_pitch(input_proc))
                y_pitch = self.norm_pitch_fn(proc.compute_pitch(target_proc))
                losses += self.loss_pitch_fn(x_pitch, y_pitch) * self.lambda_pitch

            # energy/rms loss (RMS is computed from the waveform and its mel)
            if self.use_rms_loss:
                x_rms = self.norm_rms(proc.compute_rms(input_proc, x_mels))
                y_rms = self.norm_rms(proc.compute_rms(target_proc, y_mels))
                losses += self.loss_rms_fn(x_rms, y_rms) * self.lambda_rms

        return losses * self.weight
|
lt_tensor/lr_schedulers.py
CHANGED
@@ -1,15 +1,20 @@
|
|
1
1
|
__all__ = [
|
2
2
|
"WarmupDecayScheduler",
|
3
3
|
"AdaptiveDropScheduler",
|
4
|
-
"
|
4
|
+
"SinusoidalDecayLR",
|
5
|
+
"GuidedWaveringLR",
|
6
|
+
"FloorExponentialLR",
|
5
7
|
]
|
6
8
|
|
7
9
|
import math
|
8
10
|
from torch.optim import Optimizer
|
9
|
-
from torch.optim.lr_scheduler import
|
11
|
+
from torch.optim.lr_scheduler import LRScheduler
|
12
|
+
from typing import Optional
|
13
|
+
from numbers import Number
|
14
|
+
from lt_tensor.misc_utils import update_lr
|
10
15
|
|
11
16
|
|
12
|
-
class WarmupDecayScheduler(
|
17
|
+
class WarmupDecayScheduler(LRScheduler):
|
13
18
|
def __init__(
|
14
19
|
self,
|
15
20
|
optimizer: Optimizer,
|
@@ -49,7 +54,7 @@ class WarmupDecayScheduler(_LRScheduler):
|
|
49
54
|
return lrs
|
50
55
|
|
51
56
|
|
52
|
-
class AdaptiveDropScheduler(
|
57
|
+
class AdaptiveDropScheduler(LRScheduler):
|
53
58
|
def __init__(
|
54
59
|
self,
|
55
60
|
optimizer,
|
@@ -89,26 +94,147 @@ class AdaptiveDropScheduler(_LRScheduler):
|
|
89
94
|
return [group["lr"] for group in self.optimizer.param_groups]
|
90
95
|
|
91
96
|
|
92
|
-
class SinusoidalDecayLR(LRScheduler):
    """Exponentially decaying learning rate with a superimposed sine wave.

    Every ``period`` steps form one cycle. With ``c`` completed cycles and
    ``t`` the position inside the current cycle:

    * centre    = ``max(target_lr, initial_lr * exp(-decay_rate * c))``
    * amplitude = ``wave_amplitude * initial_lr * exp(-decay_rate * c)``
    * lr        = ``max(centre + amplitude * sin(2*pi*t/period), floor_lr)``

    ``wave_amplitude`` is therefore relative to ``initial_lr``. The same value
    is written to every param group; the optimizer's own starting learning
    rates are not used.

    Args:
        optimizer: Wrapped optimizer.
        initial_lr: Centre of the wave before any decay.
        target_lr: Lower limit for the decaying centre.
        floor_lr: Hard lower bound on the returned learning rate.
        decay_rate: Exponential decay applied per period (must be non-zero).
        wave_amplitude: Wave amplitude as a fraction of ``initial_lr``.
        period: Number of steps per cycle (must be positive).
        last_epoch: Index of the last epoch. Default: -1.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        initial_lr: float = 1e-3,
        target_lr: float = 1e-5,
        floor_lr: float = 1e-7,
        decay_rate: float = 1e-6,  # decay per period
        wave_amplitude: float = 1e-5,
        period: int = 256,
        last_epoch: int = -1,
    ):
        assert decay_rate != 0.0, "decay_rate must be different from 0.0"
        assert (
            initial_lr >= target_lr >= floor_lr
        ), "Must satisfy: initial_lr ≥ target_lr ≥ floor_lr"
        # Checked here so the error is clear instead of a ZeroDivisionError
        # raised from get_lr() inside LRScheduler.__init__.
        assert period > 0, "period must be a positive number of steps"

        self.initial_lr = initial_lr
        self.target_lr = target_lr
        self.floor_lr = floor_lr
        self.decay_rate = decay_rate
        self.wave_amplitude = wave_amplitude
        self.period = period

        # Attributes must exist before this call: LRScheduler.__init__ runs an
        # initial step(), which calls get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        cycles, t = divmod(step, self.period)
        # Centre and amplitude share the same decay factor.
        decay = math.exp(-self.decay_rate * cycles)
        center = max(self.target_lr, self.initial_lr * decay)
        amplitude = self.wave_amplitude * self.initial_lr * decay
        sin_offset = math.sin(2 * math.pi * t / self.period)
        lr = max(center + amplitude * sin_offset, self.floor_lr)
        return [lr for _ in self.optimizer.param_groups]
|
136
|
+
|
137
|
+
|
138
|
+
class GuidedWaveringLR(LRScheduler):
    """Exponentially decaying learning rate with a ``sin * cos`` ripple.

    Every ``period`` steps form one cycle. With ``c`` completed cycles (capped
    at ``stop_decay_after`` when it is not None) and ``t`` the position inside
    the current cycle, ``phase = 2*pi*t/period``:

    * centre = ``max(target_lr, initial_lr * exp(-decay_rate * c))``
    * ripple = ``wave_amplitude * initial_lr * exp(-decay_rate * c)
      * sin(phase) * cos(phase)``
    * lr     = ``max(centre + ripple, floor_lr)``

    Since ``sin(x) * cos(x) == sin(2x) / 2``, the ripple peaks at half of
    ``wave_amplitude * initial_lr * decay`` and oscillates twice per period.
    The same value is written to every param group.

    Args:
        optimizer: Wrapped optimizer.
        initial_lr: Centre before any decay.
        target_lr: Lower limit for the decaying centre.
        floor_lr: Hard lower bound on the returned learning rate.
        decay_rate: Exponential decay per cycle (must be non-zero).
        wave_amplitude: Ripple amplitude as a fraction of ``initial_lr``.
        period: Number of steps per cycle (must be positive).
        stop_decay_after: Number of cycles after which decay is frozen.
            ``None`` never freezes; ``0`` disables decay entirely.
        last_epoch: Index of the last epoch. Default: -1.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        initial_lr: float = 1e-3,
        target_lr: float = 1e-5,
        floor_lr: float = 1e-7,
        decay_rate: float = 0.01,
        wave_amplitude: float = 0.02,
        period: int = 256,
        stop_decay_after: Optional[int] = None,
        last_epoch: int = -1,
    ):
        assert decay_rate != 0.0, "decay_rate must be non-zero"
        assert (
            initial_lr >= target_lr >= floor_lr
        ), "Must satisfy: initial ≥ target ≥ floor"
        assert period > 0, "period must be a positive number of steps"
        # A negative cap would make exp(-decay_rate * cap) > 1 and grow the LR.
        assert (
            stop_decay_after is None or stop_decay_after >= 0
        ), "stop_decay_after must be None or >= 0"

        self.initial_lr = initial_lr
        self.target_lr = target_lr
        self.floor_lr = floor_lr
        self.decay_rate = decay_rate
        self.wave_amplitude = wave_amplitude
        self.period = period
        self.stop_decay_after = stop_decay_after

        # Attributes must exist first: LRScheduler.__init__ calls get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        cycles, t = divmod(step, self.period)

        # `is not None` (not truthiness) so that stop_decay_after=0 freezes
        # decay immediately instead of being treated as "never stop".
        if self.stop_decay_after is not None:
            decay_cycles = min(cycles, self.stop_decay_after)
        else:
            decay_cycles = cycles
        decay = math.exp(-self.decay_rate * decay_cycles)

        center = max(self.target_lr, self.initial_lr * decay)
        amp = self.wave_amplitude * self.initial_lr * decay
        phase = 2 * math.pi * t / self.period
        wave = math.sin(phase) * math.cos(phase)
        lr = max(center + amp * wave, self.floor_lr)
        return [lr for _ in self.optimizer.param_groups]
|
186
|
+
|
187
|
+
|
188
|
+
class FloorExponentialLR(LRScheduler):
    """Exponential learning-rate decay with a lower bound and a reset helper.

    Each ``step()`` multiplies every param group's current ``lr`` by ``gamma``
    and clamps the result at ``floor_lr``. On the first step
    (``last_epoch == 0``) the optimizer's current learning rates are kept,
    only clamped at ``floor_lr``.

    Args:
        optimizer: Wrapped optimizer.
        initial_lr: Learning rate restored by ``reset_lr()`` when called
            without a value. It is not applied at construction; the
            optimizer's own ``lr`` is the starting point.
        gamma: Multiplicative decay factor applied per step.
        last_epoch: Index of the last epoch. Default: -1.
        floor_lr: Minimum learning rate produced by the schedule.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        initial_lr: float = 1e-4,
        gamma: float = 0.99998,
        last_epoch: int = -1,
        floor_lr: float = 1e-6,
    ):
        self.gamma = gamma
        self.floor_lr = floor_lr
        self.initial_lr = initial_lr

        # Attributes must exist first: LRScheduler.__init__ calls get_lr().
        super().__init__(optimizer, last_epoch)

    def set_floor(self, new_value: float):
        """Replace the lower bound used from the next ``step()`` on."""
        assert isinstance(new_value, Number)
        self.floor_lr = new_value

    def reset_lr(self, new_value: Optional[float] = None):
        """Set every param group's lr to ``new_value`` (or ``initial_lr``).

        A numeric ``new_value`` also becomes the new ``initial_lr``. Because
        ``get_lr`` decays from the groups' current ``lr``, decay continues
        from this value on the next ``step()``.
        """
        new_lr = new_value if isinstance(new_value, Number) else self.initial_lr
        self.initial_lr = new_lr
        update_lr(self.optimizer, new_lr)
        # Keep get_last_lr() consistent with the learning rates just written;
        # otherwise it reports the pre-reset values until the next step().
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def get_lr(self):
        # First step: keep the optimizer's lr, only enforce the floor.
        if self.last_epoch == 0:
            return [
                max(group["lr"], self.floor_lr) for group in self.optimizer.param_groups
            ]

        return [
            max(group["lr"] * self.gamma, self.floor_lr)
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self):
        # Used by LRScheduler.step(epoch) when an explicit epoch is given.
        return [
            max(base_lr * self.gamma**self.last_epoch, self.floor_lr)
            for base_lr in self.base_lrs
        ]
|
lt_tensor/misc_utils.py
CHANGED
@@ -24,6 +24,7 @@ __all__ = [
|
|
24
24
|
"plot_view",
|
25
25
|
"get_weights",
|
26
26
|
"get_activated_conv",
|
27
|
+
"update_lr",
|
27
28
|
]
|
28
29
|
|
29
30
|
import re
|
@@ -77,6 +78,33 @@ def get_activated_conv(
|
|
77
78
|
)
|
78
79
|
|
79
80
|
|
81
|
+
def get_loss_average(losses: List[float]):
    """Return the arithmetic mean of ``losses``, or ``nan`` when it is empty.

    A little helper for training, for example:

    ```python
    for epoch in range(100):
        losses = []
        for inp, label in dataloader:
            optimizer.zero_grad()
            out = model(inp)
            loss = loss_fn(out, label)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f"Epoch {epoch+1} | Loss: {get_loss_average(losses):.4f}")
    ```

    Args:
        losses: Recorded loss values.

    Returns:
        ``sum(losses) / len(losses)``, or ``float("nan")`` for an empty list,
        so callers never hit a ZeroDivisionError.
    """
    if not losses:
        return float("nan")
    return sum(losses) / len(losses)
|
97
|
+
|
98
|
+
|
99
|
+
def update_lr(optimizer: optim.Optimizer, new_value: float = 1e-4):
    """Overwrite the learning rate of every param group in ``optimizer``.

    Tensor-valued ``lr`` entries are filled in place, so the same tensor
    object stays referenced by the param group. Any other value is replaced
    by ``new_value``.

    Args:
        optimizer: Optimizer whose ``param_groups`` are updated.
        new_value: Learning rate to set.

    Returns:
        The same ``optimizer``, for call chaining.
    """
    for group in optimizer.param_groups:
        current = group["lr"]
        if not isinstance(current, Tensor):
            group["lr"] = new_value
            continue
        current.fill_(new_value)
    return optimizer
|
106
|
+
|
107
|
+
|
80
108
|
def plot_view(
|
81
109
|
data: Dict[str, List[Any]],
|
82
110
|
title: str = "Loss",
|
@@ -520,49 +548,14 @@ def sample_tensor(tensor: torch.Tensor, num_samples: int = 5):
|
|
520
548
|
return flat[idx]
|
521
549
|
|
522
550
|
|
523
|
-
class TorchCacheUtils:
|
524
|
-
cached_shortcuts: dict[str, Callable[[None], None]] = {}
|
525
|
-
|
526
|
-
has_cuda: bool = torch.cuda.is_available()
|
527
|
-
has_xpu: bool = torch.xpu.is_available()
|
528
|
-
has_mps: bool = torch.mps.is_available()
|
529
|
-
|
530
|
-
_ignore: list[str] = []
|
531
|
-
|
532
|
-
def __init__(self):
|
533
|
-
pass
|
534
|
-
|
535
|
-
def _apply_clear(self, device: str):
|
536
|
-
if device in self._ignore:
|
537
|
-
gc.collect()
|
538
|
-
return
|
539
|
-
try:
|
540
|
-
clear_fn = self.cached_shortcuts.get(
|
541
|
-
device, getattr(torch, device).empty_cache
|
542
|
-
)
|
543
|
-
if device not in self.cached_shortcuts:
|
544
|
-
self.cached_shortcuts.update({device: clear_fn})
|
545
|
-
|
546
|
-
except Exception as e:
|
547
|
-
print(e)
|
548
|
-
self._ignore.append(device)
|
549
|
-
|
550
|
-
def clear(self):
|
551
|
-
gc.collect()
|
552
|
-
if self.has_xpu:
|
553
|
-
self._apply_clear("xpu")
|
554
|
-
if self.has_cuda:
|
555
|
-
self._apply_clear("cuda")
|
556
|
-
if self.has_mps:
|
557
|
-
self._apply_clear("mps")
|
558
|
-
gc.collect()
|
559
|
-
|
560
|
-
|
561
|
-
_clear_cache_cls = TorchCacheUtils()
|
562
|
-
|
563
|
-
|
564
551
|
def clear_cache():
    """Release unused cached device memory back to the driver.

    Python garbage collection runs first, so tensors kept alive only by
    unreachable reference cycles are freed. Then each available backend
    (CUDA, MPS, XPU) is asked to empty its allocator cache, and gc runs once
    more. Backends missing from the installed torch build are skipped instead
    of raising AttributeError.
    """
    # Collect first: blocks owned by tensors that gc frees can only be handed
    # back by empty_cache() once those tensors are gone.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # torch.backends.mps.is_available() is the documented availability check;
    # torch.mps (with empty_cache) only exists in newer torch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_module is not None
        and mps_backend.is_available()
    ):
        mps_module.empty_cache()

    # torch.xpu is absent in older torch releases.
    xpu_module = getattr(torch, "xpu", None)
    if xpu_module is not None and xpu_module.is_available():
        xpu_module.empty_cache()

    gc.collect()
|
566
559
|
|
567
560
|
|
568
561
|
@cache_wrapper
|
@@ -1,15 +1,17 @@
|
|
1
|
-
import torch
|
2
1
|
import torch.nn as nn
|
3
|
-
|
4
|
-
|
5
|
-
|
2
|
+
from lt_tensor.model_zoo.activations.alias_free.resample import (
|
3
|
+
UpSample2d,
|
4
|
+
DownSample2d,
|
5
|
+
UpSample1d,
|
6
|
+
DownSample1d,
|
7
|
+
)
|
6
8
|
|
7
9
|
|
8
10
|
class Activation1d(nn.Module):
|
9
11
|
|
10
12
|
def __init__(
|
11
13
|
self,
|
12
|
-
activation,
|
14
|
+
activation: nn.Module,
|
13
15
|
up_ratio: int = 2,
|
14
16
|
down_ratio: int = 2,
|
15
17
|
up_kernel_size: int = 12,
|
@@ -34,7 +36,7 @@ class Activation2d(nn.Module):
|
|
34
36
|
|
35
37
|
def __init__(
|
36
38
|
self,
|
37
|
-
activation,
|
39
|
+
activation: nn.Module,
|
38
40
|
up_ratio: int = 2,
|
39
41
|
down_ratio: int = 2,
|
40
42
|
up_kernel_size: int = 12,
|