lt-tensor 0.0.1a9__py3-none-any.whl → 0.0.1a11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/losses.py +97 -83
- lt_tensor/misc_utils.py +58 -35
- lt_tensor/model_base.py +247 -31
- lt_tensor/model_zoo/bsc.py +22 -0
- lt_tensor/model_zoo/disc.py +54 -50
- lt_tensor/model_zoo/istft.py +0 -41
- lt_tensor/noise_tools.py +1 -2
- lt_tensor/transform.py +13 -37
- {lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/RECORD +13 -13
- {lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/top_level.txt +0 -0
lt_tensor/losses.py
CHANGED
@@ -1,10 +1,106 @@
-__all__ = [
+__all__ = [
+    "masked_cross_entropy",
+    "adaptive_l1_loss",
+    "contrastive_loss",
+    "smooth_l1_loss",
+    "hybrid_loss",
+    "diff_loss",
+    "cosine_loss",
+    "gan_loss",
+    "ft_n_loss",
+]
 import math
 import random
 from .torch_commons import *
 from lt_utils.common import *
 import torch.nn.functional as F
 
+def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+    if weight is not None:
+        return torch.mean((torch.abs(output - target) + weight) ** 0.5)
+    return torch.mean(torch.abs(output - target) ** 0.5)
+
+
+def adaptive_l1_loss(
+    inp: Tensor,
+    tgt: Tensor,
+    weight: Optional[Tensor] = None,
+    scale: float = 1.0,
+    inverted: bool = False,
+):
+
+    if weight is not None:
+        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+    else:
+        loss = torch.mean(torch.abs(inp - tgt))
+    loss *= scale
+    if inverted:
+        return -loss
+    return loss
+
+
+def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+    diff = torch.abs(inp - tgt)
+    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+    if weight is not None:
+        loss *= weight
+    return loss.mean()
+
+
+def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+    # label == 1: similar, label == 0: dissimilar
+    dist = torch.nn.functional.pairwise_distance(x1, x2)
+    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+    return loss.mean()
+
+
+def cosine_loss(inp, tgt):
+    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+    return 1 - cos.mean()  # Lower is better
+
+
+class GanLosses:
+    @staticmethod
+    def get_loss(
+        pred: Tensor,
+        target_is_real: bool,
+        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
+    ) -> Tensor:
+        if loss_type == "bce":  # Standard GAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.binary_cross_entropy_with_logits(pred, target)
+
+        elif loss_type == "mse":  # LSGAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.mse_loss(torch.sigmoid(pred), target)
+
+        elif loss_type == "hinge":
+            if target_is_real:
+                return torch.mean(F.relu(1.0 - pred))
+            else:
+                return torch.mean(F.relu(1.0 + pred))
+
+        elif loss_type == "wasserstein":
+            return -pred.mean() if target_is_real else pred.mean()
+
+        else:
+            raise ValueError(f"Unknown loss_type: {loss_type}")
+
+    @staticmethod
+    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
+        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
+
+    @staticmethod
+    def discriminator_loss(
+        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
+    ) -> Tensor:
+        real_loss = GanLosses.get_loss(
+            real_pred, target_is_real=True, loss_type=loss_type
+        )
+        fake_loss = GanLosses.get_loss(
+            fake_pred.detach(), target_is_real=False, loss_type=loss_type
+        )
+        return (real_loss + fake_loss) * 0.5
+
 
 def masked_cross_entropy(
     logits: torch.Tensor,  # [B, T, V]
@@ -61,85 +157,3 @@ def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
             torch.log(1 - fake + 1e-7)
         )
     return loss
-
-
-def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
-    loss = 0
-    for real, fake in zip(real_preds, fake_preds):
-        if use_lsgan:
-            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
-                fake, torch.zeros_like(fake)
-            )
-        else:
-            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
-                torch.log(1 - fake + 1e-7)
-            )
-    return loss
-
-
-def gan_g_loss(fake_preds, use_lsgan=True):
-    loss = 0
-    for fake in fake_preds:
-        if use_lsgan:
-            loss += F.mse_loss(fake, torch.ones_like(fake))
-        else:
-            loss += -torch.mean(torch.log(fake + 1e-7))
-    return loss
-
-
-def feature_matching_loss(real_feats, fake_feats):
-    """real_feats and fake_feats are lists of intermediate features"""
-    loss = 0
-    for real_layers, fake_layers in zip(real_feats, fake_feats):
-        for r, f in zip(real_layers, fake_layers):
-            loss += F.l1_loss(f, r.detach())
-    return loss
-
-
-def feature_loss(real_fmaps, fake_fmaps, weight=2.0):
-    loss = 0.0
-    for dr, dg in zip(real_fmaps, fake_fmaps):  # Each (layer list from a discriminator)
-        for r_feat, g_feat in zip(dr, dg):
-            loss += F.l1_loss(r_feat, g_feat)
-    return loss * weight
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0.0
-    r_losses = []
-    g_losses = []
-
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = F.mse_loss(dr, torch.ones_like(dr))
-        g_loss = F.mse_loss(dg, torch.zeros_like(dg))
-        loss += r_loss + g_loss
-        r_losses.append(r_loss)
-        g_losses.append(g_loss)
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(fake_outputs):
-    total = 0.0
-    g_losses = []
-    for out in fake_outputs:
-        loss = F.mse_loss(out, torch.ones_like(out))
-        g_losses.append(loss)
-        total += loss
-    return total, g_losses
-
-
-def multi_resolution_stft_loss(y, y_hat, fft_sizes=[512, 1024, 2048]):
-    loss = 0
-    for fft_size in fft_sizes:
-        hop = fft_size // 4
-        win = fft_size
-        y_stft = torch.stft(
-            y, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
-        )
-        y_hat_stft = torch.stft(
-            y_hat, n_fft=fft_size, hop_length=hop, win_length=win, return_complex=True
-        )
-
-        loss += F.l1_loss(torch.abs(y_stft), torch.abs(y_hat_stft))
-    return loss
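A minimal sketch of how the new `GanLosses` helper and `ft_n_loss` might be driven; the tensor shapes and the choice of hinge loss here are illustrative assumptions, not part of the package:

    import torch
    from lt_tensor.losses import GanLosses, ft_n_loss

    # Hypothetical discriminator logits of shape [batch, 1].
    real_pred = torch.randn(8, 1)
    fake_pred = torch.randn(8, 1)

    # Discriminator: real scored as real, detached fake scored as fake.
    d_loss = GanLosses.discriminator_loss(real_pred, fake_pred, loss_type="hinge")
    # Generator: rewarded when its fakes are scored as real.
    g_loss = GanLosses.generator_loss(fake_pred, loss_type="hinge")

    # ft_n_loss: mean of |output - target| ** 0.5, optionally offset by `weight`.
    recon = ft_n_loss(torch.randn(8, 100), torch.randn(8, 100))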
lt_tensor/misc_utils.py
CHANGED
@@ -15,6 +15,7 @@ __all__ = [
     "TorchCacheUtils",
     "clear_cache",
     "default_device",
+    "soft_restore",
     "Packing",
     "Padding",
     "Masking",
@@ -35,14 +36,29 @@ from lt_utils.misc_utils import ff_list
 import torch.nn.functional as F
 
 
+def soft_restore(tensor, epsilon=1e-6):
+    return torch.where(tensor == 0, torch.full_like(tensor, epsilon), tensor)
+
+
 def try_torch(fn: str, *args, **kwargs):
+    tryed_torch = False
+    not_present_message = (
+        f"Both `torch` and `torch.nn.functional` does not contain the module `{fn}`"
+    )
     try:
-
-
-
+        if hasattr(F, fn):
+            return getattr(F, fn)(*args, **kwargs)
+        elif hasattr(torch, fn):
+            tryed_torch = True
             return getattr(torch, fn)(*args, **kwargs)
+        return not_present_message
+    except Exception as a:
+        try:
+            if not tryed_torch and hasattr(torch, fn):
+                return getattr(torch, fn)(*args, **kwargs)
+            return str(a)
         except Exception as e:
-            return str(e)
+            return str(e) + " | " + str(a)
 
 
 def log_tensor(
@@ -152,7 +168,12 @@ class LogTensor:
             print(message)
             sys.stdout.flush()
 
-    def _process(
+    def _process(
+        self,
+        name: str,
+        resp: Union[Callable[[Any], Any], Any],
+        do_not_print: bool = False,
+    ):
         if callable(resp):
             try:
                 response = resp()
@@ -161,15 +182,16 @@ class LogTensor:
                 print(e)
         else:
             response = resp
-
-
-
+        if not do_not_print and self.do_print:
+            msg = self._setup_message(name, response)
+            self._print(msg)
+        return dict(item=name, value=response)
 
     def __call__(
         self,
-
+        inputs: Union[Tensor, np.ndarray, Sequence[Number]],
         title: Optional[str] = None,
-
+        target: Optional[Union[Tensor, np.ndarray, Sequence[Number]]] = None,
         *,
         log_tensor: bool = False,
         log_device: bool = False,
@@ -177,7 +199,7 @@ class LogTensor:
         log_std: bool = False,
         dim_mean: int = -1,
         dim_std: int = -1,
-
+        print_extended: bool = False,
         external_logs: List[Tuple[str, Union[Dict[str, Any], List[Any], Any]]] = [
            ("softmax", {"dim": 0}),
            ("relu", None),
@@ -186,12 +208,11 @@ class LogTensor:
            ("max", dict(dim=-1)),
        ],
        validate_item_type: bool = False,
+       exclude_invalid_losses: bool = True,
        **kwargs,
    ):
-
-
-        invalid_type = not isinstance(main_item, (Tensor, np.ndarray, list, tuple))
-        _main_item_tp = type(main_item)
+        invalid_type = not isinstance(inputs, (Tensor, np.ndarray, list, tuple))
+        _main_item_tp = type(inputs)
         assert (
             not validate_item_type or not invalid_type
         ), f"Invalid Type: {_main_item_tp}"
@@ -199,8 +220,8 @@ class LogTensor:
             self._print(f"Invalid Type: {_main_item_tp}")
             return
 
-
-
+        inputs = self._setup_tensor(inputs)
+        target = self._setup_tensor(target)
         if is_str(title):
             title = re.sub(r"\s+", " ", title.replace("_", " "))
         has_title = is_str(title)
@@ -209,43 +230,43 @@ class LogTensor:
         else:
             title = "Unnamed"
 
-        current_register = {"title": title}
-
-        current_register.update(self._process("shape", main_item.shape))
-        current_register.update(self._process("ndim", main_item.ndim))
-        current_register.update(self._process("dtype", main_item.dtype))
+        current_register = {"title": title, "values": []}
+        current_register["shape"] = self._process("shape", inputs.shape)
+        current_register["ndim"] = self._process("ndim", inputs.ndim)
+        current_register["dtype"] = self._process("dtype", inputs.dtype)
         if log_device:
-            current_register.update(self._process("device", main_item.device))
+            current_register["device"] = self._process("device", inputs.device)
 
         if log_mean:
-            fn = lambda: main_item.mean(
+            fn = lambda: inputs.mean(
                 dim=dim_mean,
             )
-            current_register.update(self._process("mean", fn))
+            current_register["mean"] = self._process("mean", fn)
 
         if log_std:
-            fn = lambda: main_item.std(dim=dim_std)
-            current_register.update(self._process("std", fn))
+            fn = lambda: inputs.std(dim=dim_std)
+            current_register["std"] = self._process("std", fn)
 
         if external_logs:
             old_print = self.do_print
-            self.do_print =
+            self.do_print = print_extended
             self._print("\n---[ External Logs ] ---")
             for log_fn, log_args in external_logs:
                 if isinstance(log_args, Sequence) and not isinstance(log_args, str):
-                    value = try_torch(log_fn, main_item, *log_args)
+                    value = try_torch(log_fn, inputs, *log_args)
                 elif isinstance(log_args, dict):
-                    value = try_torch(log_fn, main_item, **log_args)
                 elif log_args is None:
-                    value = try_torch(log_fn, main_item)
+                    value = try_torch(log_fn, inputs)
                 else:
-                    value = try_torch(log_fn, main_item, log_args)
+                    value = try_torch(log_fn, inputs, log_args)
+
                 results = self._process(log_fn, value)
                 current_register[log_fn] = results
             self.do_print = old_print
 
-        if
-            losses = get_losses(
+        if target is not None:
+            losses = get_losses(inputs, target, exclude_invalid_losses)
             started_ls = False
             if self.do_print:
                 for loss, res in losses.items():
@@ -258,7 +279,9 @@ class LogTensor:
         current_register["loss"] = losses
 
         if log_tensor:
-            current_register.
+            current_register["values"].append(
+                self._process("Tensor", inputs, not print_extended)
+            )
 
         self._print(self.end_with)
         self._store_item_and_update(current_register)
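For reference, a small sketch of the reworked `try_torch` dispatch order and the new `soft_restore` helper; the example values are assumptions:

    import torch
    from lt_tensor.misc_utils import soft_restore, try_torch

    x = torch.tensor([0.0, 2.0, -1.0])

    # soft_restore swaps exact zeros for a small epsilon, leaving other values intact.
    print(soft_restore(x))  # tensor([1.0000e-06, 2.0000e+00, -1.0000e+00])

    # try_torch resolves a name against torch.nn.functional first, then torch,
    # and returns an error string instead of raising when everything fails.
    print(try_torch("relu", x))      # dispatches to F.relu
    print(try_torch("argsort", x))   # not in F, falls back to torch.argsort
    print(try_torch("no_such_fn"))   # returns the "not present" message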
lt_tensor/model_base.py
CHANGED
@@ -1,10 +1,12 @@
-__all__ = ["Model"]
-
+__all__ = ["Model", "_ModelExtended", "LossTracker"]
 
+import gc
+import json
+import math
 import warnings
 from .torch_commons import *
 from lt_utils.common import *
-from lt_utils.misc_utils import log_traceback
+from lt_utils.misc_utils import log_traceback, get_current_time
 
 T = TypeVar("T")
 
@@ -17,6 +19,80 @@ POSSIBLE_OUTPUT_TYPES: TypeAlias = Union[
 ]
 
 
+class LossTracker:
+    last_file = f"logs/history_{get_current_time()}.json"
+
+    def __init__(self, max_len=50_000):
+        self.max_len = max_len
+        self.history = {
+            "train": [],
+            "eval": [],
+        }
+
+    def append(self, loss: float, mode: Literal["train", "eval"] = "train"):
+        assert mode in self.history, f"Invalid mode '{mode}'. Use 'train' or 'eval'."
+        self.history[mode].append(float(loss))
+        if len(self.history[mode]) > self.max_len:
+            self.history[mode] = self.history[mode][-self.max_len :]
+
+    def get(self, mode: Literal["train", "eval"] = "train"):
+        return self.history.get(mode, [])
+
+    def save(self, path: Optional[PathLike] = None):
+        if path is None:
+            path = f"logs/history_{get_current_time()}.json"
+
+        Path(path).parent.mkdir(exist_ok=True, parents=True)
+        with open(path, "w") as f:
+            json.dump(self.history, f, indent=2)
+
+        self.last_file = path
+
+    def load(self, path: Optional[PathLike] = None):
+        if path is None:
+            _path = self.last_file
+        else:
+            _path = path
+        with open(_path) as f:
+            self.history = json.load(f)
+        if path is not None:
+            self.last_file = path
+
+    def plot(self, backend: Literal["matplotlib", "plotly"] = "plotly"):
+        if backend == "plotly":
+            try:
+                import plotly.graph_objs as go
+            except ModuleNotFoundError:
+                warnings.warn(
+                    "No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
+                )
+                return
+            fig = go.Figure()
+            for mode, losses in self.history.items():
+                if losses:
+                    fig.add_trace(go.Scatter(y=losses, name=mode.capitalize()))
+            fig.update_layout(
+                title="Training vs Evaluation Loss",
+                xaxis_title="Step",
+                yaxis_title="Loss",
+                template="plotly_dark",
+            )
+            fig.show()
+
+        elif backend == "matplotlib":
+            import matplotlib.pyplot as plt
+
+            for mode, losses in self.history.items():
+                if losses:
+                    plt.plot(losses, label=f"{mode.capitalize()} Loss")
+            plt.title("Loss over Time")
+            plt.xlabel("Step")
+            plt.ylabel("Loss")
+            plt.legend()
+            plt.grid(True)
+            plt.show()
+
+
 class Model(nn.Module, ABC):
     """
     This makes it easier to assign a device and retrieves it later
@@ -24,6 +100,8 @@ class Model(nn.Module, ABC):
 
     _device: torch.device = ROOT_DEVICE
     _autocast: bool = False
+    _loss_history: LossTracker = LossTracker(100_000)
+    _is_unfrozen: bool = False
 
     @property
     def autocast(self):
@@ -61,6 +139,7 @@ class Model(nn.Module, ABC):
         if hasattr(self, weight):
             w = getattr(self, weight)
             if isinstance(w, nn.Module):
+
                 w.requires_grad_(not freeze)
             else:
                 weight.requires_grad_(not freeze)
@@ -112,21 +191,27 @@ class Model(nn.Module, ABC):
         for name, param in self.named_parameters():
             if no_exclusions:
                 try:
-                    param.requires_grad_(False)
-                    frozen.append(name)
+                    if param.requires_grad:
+                        param.requires_grad_(False)
+                        frozen.append(name)
+                    else:
+                        not_frozen.append((name, "was_frozen"))
                 except Exception as e:
                     not_frozen.append((name, str(e)))
             elif any(layer in name for layer in exclude):
                 try:
-                    param.requires_grad_(False)
-                    frozen.append(name)
+                    if param.requires_grad:
+                        param.requires_grad_(False)
+                        frozen.append(name)
+                    else:
+                        not_frozen.append((name, "was_frozen"))
                 except Exception as e:
                     not_frozen.append((name, str(e)))
             else:
-                not_frozen.append((name, "
+                not_frozen.append((name, "excluded"))
         return dict(frozen=frozen, not_frozen=not_frozen)
 
-    def
+    def unfreeze_all(self, exclude: Optional[list[str]] = None):
         """Unfreezes all model parameters except specified layers."""
         no_exclusions = not exclude
         unfrozen = []
@@ -134,18 +219,24 @@ class Model(nn.Module, ABC):
         for name, param in self.named_parameters():
             if no_exclusions:
                 try:
-                    param.requires_grad_(True)
-                    unfrozen.append(name)
+                    if not param.requires_grad:
+                        param.requires_grad_(True)
+                        unfrozen.append(name)
+                    else:
+                        not_unfrozen.append((name, "was_unfrozen"))
                 except Exception as e:
                     not_unfrozen.append((name, str(e)))
             elif any(layer in name for layer in exclude):
                 try:
-                    param.requires_grad_(True)
-                    unfrozen.append(name)
+                    if not param.requires_grad:
+                        param.requires_grad_(True)
+                        unfrozen.append(name)
+                    else:
+                        not_unfrozen.append((name, "was_unfrozen"))
                 except Exception as e:
                     not_unfrozen.append((name, str(e)))
             else:
-                not_unfrozen.append((name, "
+                not_unfrozen.append((name, "excluded"))
         return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
 
     def to(self, *args, **kwargs):
@@ -192,6 +283,7 @@ class Model(nn.Module, ABC):
 
         self._apply(convert)
         self.device = device
+        self._apply_device_to()
         return self
 
     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -202,6 +294,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
        self.device = dvc
+        self._apply_device_to()
         return self
 
     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -212,6 +305,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -222,6 +316,7 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
@@ -232,11 +327,13 @@ class Model(nn.Module, ABC):
             ":" + str(device) if isinstance(device, (int, float)) else device.index
         )
         self.device = dvc
+        self._apply_device_to()
         return self
 
     def cpu(self) -> T:
         super().cpu()
         self.device = "cpu"
+        self._apply_device_to()
         return self
 
     def count_trainable_parameters(self, module_name: Optional[str] = None):
@@ -314,15 +411,26 @@ class Model(nn.Module, ABC):
         else:
             print(f"Non-Trainable Parameters: {params}")
 
-    def save_weights(
+    def save_weights(
+        self,
+        path: Union[Path, str],
+        replace: bool = False,
+    ):
         path = Path(path)
+        model_dir = path
         if path.exists():
-
-            path
-
-
+            if path.is_dir():
+                model_dir = Path(path, f"model_{get_current_time()}.pt")
+            elif path.is_file():
+                if replace:
+                    path.unlink()
+                else:
+                    model_dir = Path(path.parent, f"model_{get_current_time()}.pt")
+        else:
+            if not "." in str(path):
+                model_dir = Path(path, f"model_{get_current_time()}.pt")
         path.parent.mkdir(exist_ok=True, parents=True)
-        torch.save(obj=self.state_dict(), f=str(path))
+        torch.save(obj=self.state_dict(), f=str(model_dir))
 
     def load_weights(
         self,
@@ -338,7 +446,14 @@ class Model(nn.Module, ABC):
         if not path.exists():
             assert not raise_if_not_exists, "Path does not exists!"
             return None
-
+        if path.is_dir():
+            possible_files = list(Path(path).rglob("*.pt"))
+            assert (
+                possible_files or not raise_if_not_exists
+            ), "No model could be found in the given path!"
+            if not possible_files:
+                return None
+            path = sorted(possible_files)[-1]
         state_dict = torch.load(
             str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
         )
@@ -353,30 +468,131 @@ class Model(nn.Module, ABC):
     def inference(self, *args, **kwargs):
         if self.training:
             self.eval()
-        if self.autocast:
-            with torch.autocast(device_type=self.device.type):
-                return self(*args, **kwargs)
         return self(*args, **kwargs)
 
     def train_step(
         self,
-        *
+        *inputs,
         **kwargs,
     ):
         """Train Step"""
         if not self.training:
             self.train()
-        return self(*
-
-    @torch.autocast(device_type=_device.type)
-    def ac_forward(self, *args, **kwargs):
-        return
+        return self(*inputs, **kwargs)
 
     def __call__(self, *args, **kwds) -> POSSIBLE_OUTPUT_TYPES:
-
+        if self.autocast and not self.training:
+            with torch.autocast(device_type=self.device.type):
+                return super().__call__(*args, **kwds)
+        else:
+            return super().__call__(*args, **kwds)
 
     @abstractmethod
     def forward(
         self, *args, **kwargs
     ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
         pass
+
+    def add_loss(
+        self, loss: Union[float, list[float]], mode: Literal["train", "eval"] = "train"
+    ):
+        if isinstance(loss, Number) and loss:
+            self._loss_history.append(loss, mode)
+        elif isinstance(loss, (list, tuple)):
+            if loss:
+                self._loss_history.append(sum(loss) / len(loss), mode=mode)
+        elif isinstance(loss, Tensor):
+            try:
+                self._loss_history.append(loss.detach().flatten().mean().item())
+            except Exception as e:
+                log_traceback(e, "add_loss - Tensor")
+
+    def save_loss_history(self, path: Optional[PathLike] = None):
+        self._loss_history.save(path)
+
+    def load_loss_history(self, path: Optional[PathLike] = None):
+        self._loss_history.load(path)
+
+    def get_loss_avg(self, mode: Literal["train", "eval"], quantity: int = 0):
+        t_list = self._loss_history.get("train")
+        if not t_list:
+            return float("nan")
+        if quantity > 0:
+            t_list = t_list[-quantity:]
+        return sum(t_list) / len(t_list)
+
+    def freeze_unfreeze_loss(
+        self,
+        losses: Optional[Union[float, List[float]]] = None,
+        trigger_loss: float = 0.1,
+        excluded_modules: Optional[List[str]] = None,
+        eval_last: int = 1000,
+    ):
+        """If a certain threshold is reached the weights will freeze or unfreeze the modules.
+        the biggest use-case for this function is when training GANs where the balance
+        from the discriminator and generator must be kept.
+
+        Args:
+            losses (Union[float, List[float]], Optional): The loss value or a list of losses that will be used to determine if it has reached or not the threshold. Defaults to None.
+            trigger_loss (float, optional): The value where the weights will be either freeze or unfreeze. Defaults to 0.1.
+            excluded_modules (list[str], optional): The list of modules (names) that is not to be changed by either freezing nor unfreezing. Defaults to None.
+            eval_last (float, optional): The number of previous losses to be locked behind to calculate the current averange. Default to 1000.
+
+        returns:
+            bool: True when its frozen and false when its trainable.
+        """
+        if losses is not None:
+            calculated = None
+            self.add_loss(losses)
+
+        value = self.get_loss_avg("train", eval_last)
+
+        if value <= trigger_loss:
+            if self._is_unfrozen:
+                self.freeze_all(excluded_modules)
+                self._is_unfrozen = False
+            return True
+        else:
+            if not self._is_unfrozen:
+                self.unfreeze_all(excluded_modules)
+                self._is_unfrozen = True
+            return False
+
+
+class _ModelExtended(Model):
+    """Planed, but not ready, maybe in the near future?"""
+
+    criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None
+    optimizer: Optional[optim.Optimizer] = None
+
+    def train_step(
+        self,
+        *inputs,
+        loss_label: Optional[Tensor] = None,
+        **kwargs,
+    ):
+        if not self.training:
+            self.train()
+        if self.optimizer is not None:
+            self.optimizer.zero_grad()
+        if self.autocast:
+            if self.criterion is None:
+                raise RuntimeError(
+                    "To use autocast during training, you must assign a criterion first!"
+                )
+            with torch.autocast(device_type=self.device.type):
+                out = self.forward(*loss_label, **kwargs)
+                loss = self.criterion(out, loss_label)
+
+            if self.optimizer is not None:
+                loss.backward()
+                self.optimizer.step()
+            return loss
+        elif self.criterion is not None:
+            out = self.forward(*loss_label, **kwargs)
+            loss = self.criterion(out, loss_label)
+            if self.optimizer is not None:
+                loss.backward()
+                self.optimizer.step()
+            return loss
+        else:
+            return self(*inputs, **kwargs)
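A hypothetical snippet of how the new LossTracker might be exercised on its own; the path and loss values are made up for illustration:

    from lt_tensor.model_base import LossTracker

    tracker = LossTracker(max_len=1_000)
    for step in range(1, 4):
        tracker.append(1.0 / step, mode="train")  # 1.0, 0.5, 0.333...
    tracker.append(0.9, mode="eval")

    print(tracker.get("train"))      # [1.0, 0.5, 0.3333333333333333]
    tracker.save("logs/demo.json")   # persists history, remembers the path
    tracker.load()                   # reloads from the last saved file
    # tracker.plot(backend="matplotlib")  # optional; needs matplotlib installed

On Model subclasses the same history backs the new add_loss, get_loss_avg, and freeze_unfreeze_loss threshold check.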
lt_tensor/model_zoo/bsc.py
CHANGED
@@ -208,3 +208,25 @@ class MultiScaleEncoder1D(Model):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: [B, C, T]
         return self.net(x)  # [B, hidden, T]
+
+
+class AudioClassifier(Model):
+    def __init__(self, n_mels: int = 80, num_classes=5):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(256, 256, kernel_size=3, padding=1, groups=4),
+            nn.BatchNorm1d(256),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(256, 256, kernel_size=3, padding=1),
+            nn.BatchNorm1d(256),
+            nn.LeakyReLU(0.2),
+            nn.AdaptiveAvgPool1d(1),  # Output shape: [B, 256, 1]
+            nn.Flatten(),  # -> [B, 256]
+            nn.Linear(256, num_classes),
+        )
+        self.eval()
+
+    def forward(self, x):
+        return self.model(x)
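A quick shape check for the new AudioClassifier; the batch size and frame count are arbitrary assumptions:

    import torch
    from lt_tensor.model_zoo.bsc import AudioClassifier

    clf = AudioClassifier(n_mels=80, num_classes=5)
    mel = torch.randn(2, 80, 120)  # assumed input layout: [B, n_mels, T]
    logits = clf(mel)              # pooled over time, then projected to classes
    print(logits.shape)            # torch.Size([2, 5])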
lt_tensor/model_zoo/disc.py
CHANGED
@@ -76,20 +76,6 @@ class PeriodDiscriminator(Model):
         return x.flatten(1, -1), f_map
 
 
-class MultiPeriodDiscriminator(Model):
-    def __init__(self, periods=[2, 3, 5, 7, 11]):
-        super().__init__()
-        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
-
-    def forward(self, x: torch.Tensor):
-        """
-        x: (B, T)
-        Returns: list of tuples of outputs from each period discriminator and the f_map.
-        """
-        return [d(x) for d in self.discriminators]
-
-
 class ScaleDiscriminator(nn.Module):
     def __init__(self, use_spectral_norm=False):
         super().__init__()
@@ -123,11 +109,11 @@ class ScaleDiscriminator(nn.Module):
 
 
 class MultiScaleDiscriminator(Model):
-    def __init__(self):
+    def __init__(self, layers: int = 3):
         super().__init__()
         self.pooling = nn.AvgPool1d(4, 2, padding=2)
         self.discriminators = nn.ModuleList(
-            [ScaleDiscriminator(i == 0) for i in range(3)]
+            [ScaleDiscriminator(i == 0) for i in range(layers)]
         )
 
     def forward(self, x: torch.Tensor):
@@ -136,57 +122,75 @@ class MultiScaleDiscriminator(Model):
         Returns: list of outputs from each scale discriminator
         """
         outputs = []
+        features = []
         for i, d in enumerate(self.discriminators):
             if i != 0:
                 x = self.pooling(x)
-
-
+            out, f_map = d(x)
+            outputs.append(out)
+            features.append(f_map)
+        return outputs, features
 
 
-class
-
-
-    def __init__(self):
+class MultiPeriodDiscriminator(Model):
+    def __init__(self, periods: List[int] = [2, 3, 5, 7, 11]):
         super().__init__()
-        self.
-        self.msd = MultiScaleDiscriminator()
-        self.print_trainable_parameters()
-
-    def _get_group_(self):
-        pass
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])
 
-    def forward(self, x: Tensor
-
+    def forward(self, x: torch.Tensor):
+        """
+        x: (B, T)
+        Returns: list of tuples of outputs from each period discriminator and the f_map.
+        """
+        # torch.log(torch.clip(x, min=clip_val))
+        out_map = []
+        feat_map = []
+        for d in self.discriminators:
+            out, feat = d(x)
+            out_map.append(out)
+            feat_map.append(feat)
+        return out_map, feat_map
 
 
-def discriminator_loss(
+def discriminator_loss(real_out_map, fake_out_map):
     loss = 0.0
-
-
-
-
-
-    )
-
+    rl, fl = [], []
+    for real_out, fake_out in zip(real_out_map, fake_out_map):
+        real_loss = torch.mean((1.0 - real_out) ** 2)
+        fake_loss = torch.mean(fake_out**2)
+        loss += real_loss + fake_loss
+        rl.append(real_loss.item())
+        fl.append(fake_loss.item())
+    return loss, sum(rl), sum(fl)
 
 
-def generator_adv_loss(
+def generator_adv_loss(fake_disc_outputs: List[Tensor]):
     loss = 0.0
-    for fake_out in
+    for fake_out in fake_disc_outputs:
         fake_score = fake_out[0]
         loss += -torch.mean(fake_score)
     return loss
 
 
-def
-
-
-
+def feature_loss(
+    fmap_r,
+    fmap_g,
+    weight=2.0,
+    loss_fn: Callable[[Tensor, Tensor], Tensor] = F.l1_loss,
 ):
     loss = 0.0
-    for
-
-
-
-
-
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += loss_fn(rl - gl)
+    return loss * weight
+
+
+def generator_loss(disc_generated_outputs):
+    loss = 0.0
+    gen_losses = []
+    for dg in disc_generated_outputs:
+        l = torch.mean((1.0 - dg) ** 2)
+        gen_losses.append(l.item())
+        loss += l
+
+    return loss, gen_losses
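A sketch of the new return convention, where both discriminators yield a list of scores plus a list of per-layer feature maps; the waveform shape follows the forward docstring and the sizes are otherwise assumed:

    import torch
    from lt_tensor.model_zoo.disc import MultiPeriodDiscriminator, discriminator_loss

    mpd = MultiPeriodDiscriminator()     # defaults to periods [2, 3, 5, 7, 11]
    real = torch.randn(4, 8192)          # (B, T), per the forward docstring
    fake = torch.randn(4, 8192)

    real_scores, real_feats = mpd(real)  # lists, one entry per period
    fake_scores, fake_feats = mpd(fake)
    d_loss, real_sum, fake_sum = discriminator_loss(real_scores, fake_scores)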
lt_tensor/model_zoo/istft.py
CHANGED
@@ -106,44 +106,3 @@ class Generator(Model):
         classname = m.__class__.__name__
         if "Conv" in classname:
             m.weight.data.normal_(mean, std)
-
-
-# Below are items found in the Rishikesh's repo that might work for this generator.
-# https://github.com/rishikksh20/iSTFTNet-pytorch/blob/781480e9563d4dff5a8cc9ef1af6c6e0cab025c8/models.py
-
-
-def feature_loss(fmap_r, fmap_g, weight=2.0):
-    """Feature matching loss between real and generated feature maps."""
-    loss = 0.0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-    return loss * weight
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    """LSGAN-style loss for real and fake predictions."""
-    loss = 0.0
-    r_losses, g_losses = [], []
-
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1.0 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_generated_outputs):
-    """LSGAN generator loss encouraging fake to look like real (close to 1)."""
-    loss = 0.0
-    gen_losses = []
-
-    for dg in disc_generated_outputs:
-        l = torch.mean((1.0 - dg) ** 2)
-        gen_losses.append(l.item())
-        loss += l
-
-    return loss, gen_losses
lt_tensor/noise_tools.py
CHANGED
@@ -271,9 +271,8 @@ class NoiseSchedulerB(nn.Module):
     def forward(
         self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
     ) -> Tensor:
-        apply_noise()
         assert (
-
+            0 <= t < self.timesteps
         ), f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."
 
         if noise is None:
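This hunk drops a stray `apply_noise()` call and completes the bounds assertion, so out-of-range steps now fail loudly. In isolation the guard behaves like this (values assumed):

    timesteps = 1000
    for t in (0, 500, 999):
        assert 0 <= t < timesteps  # accepted
    # t = 1000 or t = -1 would raise AssertionError with the formatted message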
lt_tensor/transform.py
CHANGED
@@ -420,44 +420,11 @@ class InverseTransform(Model):
         self.onesided = onesided
         self.normalized = normalized
         self.window = torch.hann_window(win_length) if window is None else window
-
-
+
     def _apply_device_to(self):
         """Applies to device while used with module `Model`"""
         self.window = self.window.to(device=self.device)
 
-    def update_settings(
-        self,
-        *,
-        n_fft: Optional[int] = None,
-        hop_length: Optional[int] = None,
-        win_length: Optional[int] = None,
-        length: Optional[int] = None,
-        window: Optional[Tensor] = None,
-        onesided: Optional[bool] = None,
-        return_complex: Optional[bool] = None,
-        center: Optional[bool] = None,
-        normalized: Optional[bool] = None,
-        **_,
-    ):
-
-        self.kwargs = dict(
-            n_fft=default(n_fft, self.n_fft),
-            hop_length=default(hop_length, self.hop_length),
-            win_length=default(win_length, self.win_length),
-            length=default(length, self.length),
-            window=default(window, self.window),
-            onesided=default(onesided, self.onesided),
-            return_complex=default(return_complex, self.return_complex),
-            center=default(center, self.center),
-            normalized=default(normalized, self.normalized),
-        )
-        if self.kwargs["onesided"] and self.kwargs["return_complex"]:
-            warnings.warn(
-                "You cannot use return_complex with `onesided` enabled. `return_complex` is set to False."
-            )
-            self.kwargs["return_complex"] = False
-
     def forward(self, spec: Tensor, phase: Tensor, **kwargs):
         """
         Perform the inverse short-time Fourier transform.
@@ -476,7 +443,16 @@ class InverseTransform(Model):
         Tensor
             Time-domain waveform reconstructed from `spec` and `phase`.
         """
-        if kwargs:
-            self.update_settings(**kwargs)
 
-        return torch.istft(
+        return torch.istft(
+            spec * torch.exp(phase * 1j),
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            normalized=self.normalized,
+            onesided=self.onesided,
+            length=self.length,
+            return_complex=self.return_complex,
+        )
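`InverseTransform.forward` now always uses the settings captured at construction time instead of mutating them per call; a usage sketch with assumed constructor arguments:

    import torch
    from lt_tensor.transform import InverseTransform

    n_fft, hop, win = 1024, 256, 1024
    inv = InverseTransform(n_fft=n_fft, hop_length=hop, win_length=win)  # assumed kwargs

    frames = 64  # magnitude/phase as a matching forward STFT would produce them
    mag = torch.rand(1, n_fft // 2 + 1, frames)
    phase = torch.rand(1, n_fft // 2 + 1, frames) * 3.14159
    wave = inv(mag, phase)  # istft of mag * exp(1j * phase) -> [B, samples]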
{lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/RECORD
CHANGED
@@ -1,28 +1,28 @@
 lt_tensor/__init__.py,sha256=uwJ7uiO18VYj8Z1V4KSOQ3ZrnowSgJWKCIiFBrzLMOI,429
-lt_tensor/losses.py,sha256=
+lt_tensor/losses.py,sha256=1wrke1e68hUBNAoPdJgKni0pJvXKcieza_R8nwBzMW4,4937
 lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
 lt_tensor/math_ops.py,sha256=ewIYkvxIy_Lab_9ExjFUgLs-oYLOu8IRRDo7f1pn3i8,2248
-lt_tensor/misc_utils.py,sha256=
-lt_tensor/model_base.py,sha256=
+lt_tensor/misc_utils.py,sha256=8LqtpmLKqCo79NdH160ByQojG8YTDcw8aHKFgOFGVLI,25425
+lt_tensor/model_base.py,sha256=a2ogixC2fUyOLqz15TzCRcGXvBam--TdmpG83jw9Of8,21543
 lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
-lt_tensor/noise_tools.py,sha256=
+lt_tensor/noise_tools.py,sha256=rfFbPsrsycWVuH9G4zZCQC9Vgi9r8hDaECcB0TZYSYQ,11345
 lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
-lt_tensor/transform.py,sha256=
+lt_tensor/transform.py,sha256=LZZ9G7ud1cojERC7N7hMAbH9GC3ImY1hBIY00kVMs-I,13492
 lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 lt_tensor/datasets/audio.py,sha256=YREyRsCvy-KS5tE0JNMWEdlIJogE1khLqhiq4wOWXVg,3777
 lt_tensor/model_zoo/__init__.py,sha256=jipEk50_DTMQbGg8FnDDukxmh7Bcwvl_QVRS3rkb7aY,283
-lt_tensor/model_zoo/bsc.py,sha256=
-lt_tensor/model_zoo/disc.py,sha256=
+lt_tensor/model_zoo/bsc.py,sha256=OQqsQDRBf6gWqoeGeEuIaTh96AqcDyTIbO8MAMNTtI4,7045
+lt_tensor/model_zoo/disc.py,sha256=9RxyHYH2nGhxLs_yoEFVgerBfH4-qdaL2Mu9akyG0_M,5841
 lt_tensor/model_zoo/fsn.py,sha256=5ySsg2OHjvTV_coPAdZQ0f7bz4ugJB8mDYsItmd61qA,2102
 lt_tensor/model_zoo/gns.py,sha256=Tirr_grONp_FFQ_L7K-zV2lvkaC39h8mMl4QDpx9vLQ,6028
-lt_tensor/model_zoo/istft.py,sha256=
+lt_tensor/model_zoo/istft.py,sha256=RV7KVY7q4CYzzsWXH4NGJQwSqrYWwHh-16Q62lKoA2k,3594
 lt_tensor/model_zoo/pos.py,sha256=N28v-rF8CELouYxQ9r45Jbd4ri5DNydwDgg7nzmQ4Ig,4471
 lt_tensor/model_zoo/rsd.py,sha256=5bba50g1Hm5kMexuJ4SwOIJuyQ1qJd8Acrq-Ax6CqE8,6958
 lt_tensor/model_zoo/tfrms.py,sha256=kauh-A13pk08SZ5OspEE5a-gPKD4rZr6tqMKWu3KGhk,4237
 lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
 lt_tensor/processors/audio.py,sha256=2Sta_KytTqGZh-ZeHpcCbqP6O8VT6QQVkx-7szA3Itc,8830
-lt_tensor-0.0.1a9.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
-lt_tensor-0.0.1a9.dist-info/METADATA,sha256=
-lt_tensor-0.0.1a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-lt_tensor-0.0.1a9.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
-lt_tensor-0.0.1a9.dist-info/RECORD,,
+lt_tensor-0.0.1a11.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
+lt_tensor-0.0.1a11.dist-info/METADATA,sha256=DNs5JZfr_mjve_GHy13Auics3BI_f1pNYBth-dQW04M,966
+lt_tensor-0.0.1a11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lt_tensor-0.0.1a11.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+lt_tensor-0.0.1a11.dist-info/RECORD,,
{lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/WHEEL
File without changes
{lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/licenses/LICENSE
File without changes
{lt_tensor-0.0.1a9.dist-info → lt_tensor-0.0.1a11.dist-info}/top_level.txt
File without changes