lt-tensor 0.0.1a13__py3-none-any.whl → 0.0.1a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/datasets/audio.py +23 -6
- lt_tensor/misc_utils.py +1 -1
- lt_tensor/model_base.py +163 -123
- lt_tensor/model_zoo/diffwave/__init__.py +0 -0
- lt_tensor/model_zoo/diffwave/model.py +200 -0
- lt_tensor/model_zoo/diffwave/params.py +58 -0
- lt_tensor/model_zoo/discriminator.py +269 -151
- lt_tensor/model_zoo/features.py +102 -11
- lt_tensor/model_zoo/istft/generator.py +10 -66
- lt_tensor/model_zoo/istft/trainer.py +224 -72
- lt_tensor/model_zoo/residual.py +136 -32
- lt_tensor/processors/audio.py +5 -16
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/METADATA +2 -2
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/RECORD +17 -14
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a13.dist-info → lt_tensor-0.0.1a15.dist-info}/top_level.txt +0 -0
lt_tensor/datasets/audio.py
CHANGED
```diff
@@ -20,6 +20,7 @@ class WaveMelDataset(Dataset):
     randomize_ranges: bool = False
     alpha_wv: float = 1.0
     limit_files: Optional[int] = None
+    min_frame_length: Optional[int] = None
     max_frame_length: Optional[int] = None

     def __init__(
@@ -27,12 +28,13 @@ class WaveMelDataset(Dataset):
         audio_processor: AudioProcessor,
         dataset_path: PathLike,
         limit_files: Optional[int] = None,
+        min_frame_length: Optional[int] = None,
        max_frame_length: Optional[int] = None,
         randomize_ranges: Optional[bool] = None,
         pre_load: bool = False,
         normalize_waves: Optional[bool] = None,
         alpha_wv: Optional[float] = None,
-
+        lib_norm: bool = True,
     ):
         super().__init__()
         assert max_frame_length is None or max_frame_length >= (
@@ -52,11 +54,15 @@ class WaveMelDataset(Dataset):
         self.randomize_ranges = randomize_ranges

         self.post_n_fft = (audio_processor.n_fft // 2) + 1
-
+        self.lib_norm = lib_norm
         if max_frame_length is not None:
             max_frame_length = max(self.post_n_fft + 1, max_frame_length)
             self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
             self.max_frame_length = max_frame_length
+        if min_frame_length is not None:
+            self.min_frame_length = max(
+                self.post_n_fft + 1, min(min_frame_length, max_frame_length)
+            )

         self.files = self.ap.find_audios(dataset_path, maximum=None)
         if limit_files:
@@ -96,21 +102,26 @@ class WaveMelDataset(Dataset):
         }

     def load_data(self, file: PathLike):
-        initial_audio = self.ap.
-            self.
-            file, normalize=self.normalize_waves, alpha=self.alpha_wv
-        )
+        initial_audio = self.ap.load_audio(
+            file, normalize=self.lib_norm, alpha=self.alpha_wv
         )
+        if self.normalize_waves:
+            initial_audio = self.ap.normalize_audio(initial_audio)
         if initial_audio.shape[-1] < self.post_n_fft:
             return None

+        if self.min_frame_length is not None:
+            if self.min_frame_length > initial_audio.shape[-1]:
+                return None
         if (
             not self.max_frame_length
             or initial_audio.shape[-1] <= self.max_frame_length
         ):
+
             audio_rms = self.ap.compute_rms(initial_audio)
             audio_pitch = self.ap.compute_pitch(initial_audio)
             audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
+
             return [
                 self._add_dict(initial_audio, audio_mel, audio_pitch, audio_rms, file)
             ]
@@ -129,6 +140,12 @@ class WaveMelDataset(Dataset):
             if fragment.shape[-1] < self.post_n_fft:
                 # Too small
                 continue
+            if (
+                self.min_frame_length is not None
+                and self.min_frame_length > fragment.shape[-1]
+            ):
+                continue
+
             audio_rms = self.ap.compute_rms(fragment)
             audio_pitch = self.ap.compute_pitch(fragment)
             audio_mel = self.ap.compute_mel(fragment, add_base=True)
```
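The a15 dataset thus separates loader-level normalization (`lib_norm`, passed to `load_audio`) from the post-load `normalize_waves` pass, and skips clips or fragments shorter than `min_frame_length`. A minimal usage sketch; the module paths are inferred from the file list above, and the `AudioProcessor` configuration is an assumption, not part of this diff:

```python
from lt_tensor.processors.audio import AudioProcessor
from lt_tensor.datasets.audio import WaveMelDataset

ap = AudioProcessor()  # assumption: default settings; tune n_fft/sample rate for your corpus
dataset = WaveMelDataset(
    audio_processor=ap,
    dataset_path="data/wavs",  # illustrative path
    lib_norm=True,             # new in a15: normalization applied inside load_audio
    normalize_waves=True,      # now a separate pass applied after loading
    min_frame_length=1024,     # new in a15: shorter clips/fragments are rejected
    max_frame_length=8192,
)
sample = dataset[0]  # per load_data: wave, mel, pitch and RMS entries for each fragment
```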
lt_tensor/misc_utils.py
CHANGED
```diff
@@ -240,7 +240,7 @@ class LogTensor:
     stored_items: List[
         Dict[str, Union[str, Number, Tensor, List[Union[Tensor, Number, str]]]]
     ] = []
-    max_stored_items: int =
+    max_stored_items: int = 8

     def _setup_message(self, title: str, t: Union[Tensor, str, int]):
         try:
```
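a15 pins `max_stored_items` to 8 (the a13 value is cut off in this rendering). For context, a minimal sketch of the bounded-history pattern such a cap supports; this helper is illustrative only, not `LogTensor`'s actual API:

```python
from typing import Any, Dict, List

class BoundedLog:
    """Illustrative: keep at most max_stored_items records, dropping the oldest."""

    def __init__(self, max_stored_items: int = 8):
        self.max_stored_items = max_stored_items
        self.stored_items: List[Dict[str, Any]] = []

    def add(self, record: Dict[str, Any]) -> None:
        self.stored_items.append(record)
        overflow = len(self.stored_items) - self.max_stored_items
        if overflow > 0:
            del self.stored_items[:overflow]  # trim the oldest entries first
```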
lt_tensor/model_base.py
CHANGED
```diff
@@ -179,6 +179,20 @@ class Model(_Devices_Base, ABC):
     """

     _is_unfrozen: bool = False
+    # list with modules that can be frozen or unfrozen
+    registered_freezable_modules: List[str] = []
+    is_frozen: bool = False
+    _is_gradient_freezable: bool = (
+        False  # to control if the module can or cannot be freezed by other modules from 'Model' class
+    )
+    # this is to be used on the case of they module requires low-rank adapters
+    _low_rank_lambda: Optional[Callable[[], nn.Module]] = (
+        None  # Example: lambda: nn.Linear(32, 32, True)
+    )
+    low_rank_adapter: Union[nn.Identity, nn.Module, nn.Sequential] = nn.Identity()
+
+    # dont save list:
+    _dont_save_items: List[str] = []

     def _apply_device_to(self):
         """Add here components that are needed to have device applied to them,
```
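These attributes form an opt-in contract: `freeze_all`/`unfreeze_all` only touch names listed in `registered_freezable_modules`, and `_low_rank_lambda` is the factory `load_lora` calls to rebuild the adapter. A sketch of a subclass wired for both; the class, layer sizes, and names are invented for illustration, and it assumes `Model`'s default constructor suffices (note the `staticmethod` wrapper, which keeps the factory from being bound as a method):

```python
import torch.nn as nn
from lt_tensor.model_base import Model

class TinyVocoder(Model):
    # Opt these submodules into freeze_all/unfreeze_all by name.
    registered_freezable_modules = ["encoder", "decoder"]
    _is_gradient_freezable = True
    # Factory that load_lora() calls to rebuild the adapter before loading weights.
    _low_rank_lambda = staticmethod(lambda: nn.Linear(32, 32, True))

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(80, 32)
        self.decoder = nn.Linear(32, 1)

    def forward(self, x):
        return self.decoder(self.low_rank_adapter(self.encoder(x)))
```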
```diff
@@ -192,116 +206,67 @@ class Model(_Devices_Base, ABC):
         """
         pass

-    def freeze_weight(self, weight: Union[str, nn.Module], freeze: bool):
-        assert isinstance(weight, (str, nn.Module))
-        if isinstance(weight, str):
-            if hasattr(self, weight):
-                w = getattr(self, weight)
-                if isinstance(w, nn.Module):
-
-                    w.requires_grad_(not freeze)
-        else:
-            weight.requires_grad_(not freeze)
-
-    def _freeze_unfreeze(
-        self,
-        weight: Union[str, nn.Module],
-        task: Literal["freeze", "unfreeze"] = "freeze",
-        _skip_except: bool = False,
-    ):
-        try:
-            assert isinstance(weight, (str, nn.Module))
-            if isinstance(weight, str):
-                w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a valid attribute of {self._get_name()}"
-                if hasattr(self, weight):
-                    w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a Module type."
-                    w = getattr(self, weight)
-                    if isinstance(w, nn.Module):
-                        w_txt = f"Successfully {task} the module '{weight}'."
-                        w.requires_grad_(task == "unfreeze")
-
-            else:
-                w.requires_grad_(task == "unfreeze")
-                w_txt = f"Successfully '{task}' the module '{weight}'."
-            return w_txt
-        except Exception as e:
-            if not _skip_except:
-                raise e
-            return str(e)
-
-    def freeze_weight(
-        self,
-        weight: Union[str, nn.Module],
-        _skip_except: bool = False,
-    ):
-        return self._freeze_unfreeze(weight, "freeze", _skip_except)
-
-    def unfreeze_weight(
-        self,
-        weight: Union[str, nn.Module],
-        _skip_except: bool = False,
-    ):
-        return self._freeze_unfreeze(weight, "freeze", _skip_except)
-
     def freeze_all(self, exclude: Optional[List[str]] = None):
         no_exclusions = not exclude
-
-
-        for name,
+        no_exclusions = not exclude
+        results = []
+        for name, module in self.named_modules():
+            if name not in self.registered_freezable_modules:
+                results.append(
+                    (
+                        name,
+                        "Unregistered module, to freeze/unfreeze it add its name into 'registered_freezable_modules'.",
+                    )
+                )
+                continue
             if no_exclusions:
-
-
-
-                        frozen.append(name)
-                    else:
-                        not_frozen.append((name, "was_frozen"))
-                except Exception as e:
-                    not_frozen.append((name, str(e)))
-            elif any(layer in name for layer in exclude):
-                try:
-                    if param.requires_grad:
-                        param.requires_grad_(False)
-                        frozen.append(name)
-                    else:
-                        not_frozen.append((name, "was_frozen"))
-                except Exception as e:
-                    not_frozen.append((name, str(e)))
+                self.change_frozen_state(True, module)
+            elif not any(exclusion in name for exclusion in exclude):
+                results.append((name, self.change_frozen_state(True, module)))
             else:
-
-        return
+                results.append((name, "excluded"))
+        return results

     def unfreeze_all(self, exclude: Optional[list[str]] = None):
         """Unfreezes all model parameters except specified layers."""
         no_exclusions = not exclude
-
-
-
+        results = []
+        for name, module in self.named_modules():
+            if name not in self.registered_freezable_modules:
+                results.append(
+                    (
+                        name,
+                        "Unregistered module, to freeze/unfreeze it add it into 'registered_freezable_modules'.",
+                    )
+                )
+                continue
             if no_exclusions:
-
-
-
-                        unfrozen.append(name)
-                    else:
-                        not_unfrozen.append((name, "was_unfrozen"))
-                except Exception as e:
-                    not_unfrozen.append((name, str(e)))
-            elif any(layer in name for layer in exclude):
-                try:
-                    if not param.requires_grad:
-                        param.requires_grad_(True)
-                        unfrozen.append(name)
-                    else:
-                        not_unfrozen.append((name, "was_unfrozen"))
-                except Exception as e:
-                    not_unfrozen.append((name, str(e)))
+                self.change_frozen_state(False, module)
+            elif not any(exclusion in name for exclusion in exclude):
+                results.append((name, self.change_frozen_state(False, module)))
             else:
-
-        return
+                results.append((name, "excluded"))
+        return results
+
+    def change_frozen_state(self, freeze: bool, module: nn.Module):
+        try:
+            if isinstance(module, Model):
+                if module._is_gradient_freezable:
+                    if freeze:
+                        return module.freeze_all()
+                    return module.unfreeze_all()
+                else:
+                    return "Not Allowed"
+            elif isinstance(module, nn.Module):
+                module.requires_grad_(not freeze)
+                return not freeze
+        except Exception as e:
+            return e

-    def
+    def trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name}
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             return sum(
                 [
```
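The rework routes freezing through registered module names instead of raw parameter iteration, and each call reports a per-module outcome. Continuing the hypothetical `TinyVocoder` from above:

```python
model = TinyVocoder()

# Freeze every registered module except those whose name contains "decoder".
report = model.freeze_all(exclude=["decoder"])
# report: [(module_name, outcome), ...]; unregistered names (including the root "")
# receive an explanatory message rather than a state change.

model.unfreeze_all()  # thaw all registered modules again
```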
```diff
@@ -318,10 +283,10 @@ class Model(_Devices_Base, ABC):
             ]
         )

-    def
+    def non_trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of non-trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name}
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             return sum(
                 [
@@ -338,10 +303,10 @@ class Model(_Devices_Base, ABC):
             ]
         )

-    def
+    def extract_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
         """Returns the weights of the model entry model or from a specified module"""
         if module_name is not None:
-            assert hasattr(self, module_name), f"Module {module_name}
+            assert hasattr(self, module_name), f"Module '{module_name}' not found."
             module = getattr(self, module_name)
             params = []
             if isinstance(module, nn.Module):
@@ -351,27 +316,29 @@ class Model(_Devices_Base, ABC):
             raise (f"{module_name} is has no weights")
         return [x.data.detach() for x in self.parameters()]

-    def
-        self,
-
-
+    def format_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        params = format(self.trainable_parameters(module_name), ",").replace(",", ".")
+        return params
+
+    def format_non_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        params = format(self.non_trainable_parameters(module_name), ",").replace(
             ",", "."
         )
-
-
+        return params
+
+    def print_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        fmt = self.format_trainable_parameters(module_name)
+        if module_name is not None:
+            print(f"Trainable parameter(s) for module '{module_name}': {fmt}")
         else:
-            print(f"Trainable
+            print(f"Trainable parameter(s): {fmt}")

-    def print_non_trainable_parameters(
-
-
-
-            ",", "."
-        )
-        if module_name:
-            print(f'Non-Trainable Parameters from "{module_name}": {params}')
+    def print_non_trainable_parameters(self, module_name: Optional[str] = None) -> str:
+        fmt = self.format_non_trainable_parameters(module_name)
+        if module_name is not None:
+            print(f"Non-Trainable parameter(s) for module '{module_name}': {fmt}")
         else:
-            print(f"Non-Trainable
+            print(f"Non-Trainable parameter(s): {fmt}")
```
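The new `format_*` helpers render parameter counts with dots as thousands separators. The underlying expression, worked through:

```python
n = 1234567
formatted = format(n, ",").replace(",", ".")  # "1,234,567" -> "1.234.567"
print(formatted)  # 1.234.567
```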
```diff
@@ -390,6 +357,7 @@ class Model(_Devices_Base, ABC):
         self,
         path: Union[Path, str],
         replace: bool = False,
+        save_with_adapters: bool = False,
     ):
         path = Path(path)
         model_dir = path
```
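`save_weights` now drops the `low_rank_adapter` entry from the saved state_dict unless `save_with_adapters=True`, keeping base checkpoints adapter-free by default. Continuing the sketch:

```python
model.save_weights("ckpt/base/")                           # adapter entry stripped
model.save_weights("ckpt/full/", save_with_adapters=True)  # adapter kept inline
```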
```diff
@@ -405,17 +373,86 @@ class Model(_Devices_Base, ABC):
             if not "." in str(path):
                 model_dir = Path(path, f"model_{get_current_time()}.pt")
         path.parent.mkdir(exist_ok=True, parents=True)
-
+
+        state_dict = self.state_dict()
+        if not save_with_adapters or isinstance(self.low_rank_adapter, nn.Identity):
+            state_dict.pop("low_rank_adapter", None)
+        torch.save(obj=state_dict, f=str(model_dir))
+
+    def save_lora(
+        self,
+        path: Union[Path, str],
+        replace: bool = False,
+    ):
+        assert not isinstance(
+            self.low_rank_adapter, nn.Identity
+        ), "The adapter is empty!"
+        path = Path(path)
+        model_dir = path
+        if path.exists():
+            if path.is_dir():
+                model_dir = Path(path, f"adapter_{get_current_time()}.pt")
+            elif path.is_file():
+                if replace:
+                    path.unlink()
+                else:
+                    model_dir = Path(path.parent, f"adapter_{get_current_time()}.pt")
+        else:
+            if not "." in str(path):
+                model_dir = Path(path, f"adapter_{get_current_time()}.pt")
+
+        state_dict = self.low_rank_adapter.state_dict()
+        torch.save(obj=state_dict, f=str(model_dir))
+
+    def load_lora(
+        self,
+        path: Union[Path, str],
+        raise_if_not_exists: bool = False,
+        strict: bool = False,
+        assign: bool = False,
+        weights_only: bool = True,
+        mmap: Optional[bool] = None,
+        **pickle_load_args,
+    ):
+        assert (
+            self._low_rank_lambda is not None
+        ), "Lora not implemented! '_low_rank_lambda' must be setup to deploy a proper module"
+        path = Path(path)
+        if not path.exists():
+            assert not raise_if_not_exists, "Path does not exists!"
+            return None
+
+        if path.is_dir():
+            possible_files = list(Path(path).rglob("adapter_*.pt"))
+            assert (
+                possible_files or not raise_if_not_exists
+            ), "No model could be found in the given path!"
+            if not possible_files:
+                return None
+            path = sorted(possible_files)[-1]
+
+        state_dict = torch.load(
+            str(path), weights_only=weights_only, mmap=mmap, **pickle_load_args
+        )
+        self.low_rank_adapter = None
+        gc.collect()
+        self.low_rank_adapter = self._low_rank_lambda()
+        incompatible_keys = self.low_rank_adapter.load_state_dict(
+            state_dict,
+            strict=strict,
+            assign=assign,
+        )
+        return incompatible_keys

     def load_weights(
         self,
         path: Union[Path, str],
         raise_if_not_exists: bool = False,
-        strict: bool =
+        strict: bool = False,
         assign: bool = False,
-        weights_only: bool =
+        weights_only: bool = True,
         mmap: Optional[bool] = None,
-        **
+        **pickle_load_args,
     ):
         path = Path(path)
         if not path.exists():
```
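`save_lora` persists only the adapter's state_dict, and `load_lora` rebuilds the adapter via `_low_rank_lambda` before loading into it, so base weights and adapters travel as separate files. A hedged round-trip sketch (paths illustrative; `TinyVocoder` hypothetical, as before):

```python
model = TinyVocoder()
model.low_rank_adapter = model._low_rank_lambda()  # deploy the adapter module

# ... fine-tune only the adapter ...

model.save_weights("ckpt/")   # base weights; adapter stripped by default
model.save_lora("ckpt/")      # writes adapter_<timestamp>.pt

fresh = TinyVocoder()
fresh.load_weights("ckpt/")
fresh.load_lora("ckpt/")      # rebuilds via _low_rank_lambda, then load_state_dict
```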
```diff
@@ -430,7 +467,7 @@ class Model(_Devices_Base, ABC):
                 return None
             path = sorted(possible_files)[-1]
         state_dict = torch.load(
-            str(path), weights_only=weights_only, mmap=mmap, **
+            str(path), weights_only=weights_only, mmap=mmap, **pickle_load_args
         )
         incompatible_keys = self.load_state_dict(
             state_dict,
@@ -439,6 +476,9 @@ class Model(_Devices_Base, ABC):
         )
         return incompatible_keys

+    def lora_step(self, *arg, **kwargs):
+        raise NotImplementedError("Not implemented for this model")
+
     @torch.no_grad()
     def inference(self, *args, **kwargs):
         if self.training:
```
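`lora_step` is only a stub; models that support adapters are expected to override it. A purely hypothetical override, to show the intent:

```python
class AdaptedVocoder(TinyVocoder):
    def lora_step(self, x, *args, **kwargs):
        # Hypothetical: route features through the adapter during fine-tuning.
        return self.low_rank_adapter(x)
```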
```diff
@@ -524,7 +564,7 @@ class Model(_Devices_Base, ABC):
             bool: True when its frozen and false when its trainable.
         """
         if losses is not None:
-            self.add_loss(losses)
+            self.add_loss(losses, "train")

         if isinstance(trigger_loss, bool):
             if trigger_loss:
```
lt_tensor/model_zoo/diffwave/__init__.py
File without changes
lt_tensor/model_zoo/diffwave/model.py
ADDED
```diff
@@ -0,0 +1,200 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from math import sqrt
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def override(self, attrs):
+        if isinstance(attrs, dict):
+            self.__dict__.update(**attrs)
+        elif isinstance(attrs, (list, tuple, set)):
+            for attr in attrs:
+                self.override(attr)
+        elif attrs is not None:
+            raise NotImplementedError
+        return self
+
+
+params = AttrDict(
+    # Training params
+    batch_size=16,
+    learning_rate=2e-4,
+    max_grad_norm=None,
+    # Data params
+    sample_rate=22050,
+    n_mels=80,
+    n_fft=1024,
+    hop_samples=256,
+    crop_mel_frames=62,  # Probably an error in paper.
+    # Model params
+    residual_layers=30,
+    residual_channels=64,
+    dilation_cycle_length=10,
+    unconditional=False,
+    noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(),
+    inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],
+    # unconditional sample len
+    audio_len=22050 * 5,  # unconditional_synthesis_samples
+)
+
+
+def Conv1d(*args, **kwargs):
+    layer = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(layer.weight)
+    return layer
+
+
+class DiffusionEmbedding(nn.Module):
+    def __init__(self, max_steps):
+        super().__init__()
+        self.register_buffer(
+            "embedding", self._build_embedding(max_steps), persistent=False
+        )
+        self.projection1 = nn.Linear(128, 512)
+        self.projection2 = nn.Linear(512, 512)
+        self.activation = nn.SiLU()
+
+    def forward(self, diffusion_step):
+        if diffusion_step.dtype in [torch.int32, torch.int64]:
+            x = self.embedding[diffusion_step]
+        else:
+            x = self._lerp_embedding(diffusion_step)
+        x = self.projection1(x)
+        x = self.activation(x)
+        x = self.projection2(x)
+        x = self.activation(x)
+        return x
+
+    def _lerp_embedding(self, t):
+        low_idx = torch.floor(t).long()
+        high_idx = torch.ceil(t).long()
+        low = self.embedding[low_idx]
+        high = self.embedding[high_idx]
+        return low + (high - low) * (t - low_idx)
+
+    def _build_embedding(self, max_steps):
+        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
+        dims = torch.arange(64).unsqueeze(0)  # [1,64]
+        table = steps * 10.0 ** (dims * 4.0 / 63.0)  # [T,64]
+        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
+        return table
+
+
+class SpectrogramUpsampler(nn.Module):
+    def __init__(self, n_mels):
+        super().__init__()
+        self.conv1 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
+        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 1)
+        x = self.conv1(x)
+        x = F.leaky_relu(x, 0.4)
+        x = self.conv2(x)
+        x = F.leaky_relu(x, 0.4)
+        x = torch.squeeze(x, 1)
+        return x
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, n_mels, residual_channels, dilation, uncond=False):
+        """
+        :param n_mels: inplanes of conv1x1 for spectrogram conditional
+        :param residual_channels: audio conv
+        :param dilation: audio conv dilation
+        :param uncond: disable spectrogram conditional
+        """
+        super().__init__()
+        self.dilated_conv = Conv1d(
+            residual_channels,
+            2 * residual_channels,
+            3,
+            padding=dilation,
+            dilation=dilation,
+        )
+        self.diffusion_projection = nn.Linear(512, residual_channels)
+        if not uncond:  # conditional model
+            self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
+        else:  # unconditional model
+            self.conditioner_projection = None
+
+        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
+
+    def forward(self, x, diffusion_step, conditioner=None):
+        assert (conditioner is None and self.conditioner_projection is None) or (
+            conditioner is not None and self.conditioner_projection is not None
+        )
+
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        y = x + diffusion_step
+        if self.conditioner_projection is None:  # using a unconditional model
+            y = self.dilated_conv(y)
+        else:
+            conditioner = self.conditioner_projection(conditioner)
+            y = self.dilated_conv(y) + conditioner
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+        return (x + residual) / sqrt(2.0), skip
+
+
+class DiffWave(nn.Module):
+    def __init__(self, params):
+        super().__init__()
+        self.params = params
+        self.input_projection = Conv1d(1, params.residual_channels, 1)
+        self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
+        if self.params.unconditional:  # use unconditional model
+            self.spectrogram_upsampler = None
+        else:
+            self.spectrogram_upsampler = SpectrogramUpsampler(params.n_mels)
+
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    params.n_mels,
+                    params.residual_channels,
+                    2 ** (i % params.dilation_cycle_length),
+                    uncond=params.unconditional,
+                )
+                for i in range(params.residual_layers)
+            ]
+        )
+        self.skip_projection = Conv1d(
+            params.residual_channels, params.residual_channels, 1
+        )
+        self.output_projection = Conv1d(params.residual_channels, 1, 1)
+        nn.init.zeros_(self.output_projection.weight)
+
+    def forward(self, audio, diffusion_step, spectrogram=None):
+        assert (spectrogram is None and self.spectrogram_upsampler is None) or (
+            spectrogram is not None and self.spectrogram_upsampler is not None
+        )
+        x = audio.unsqueeze(1)
+        x = self.input_projection(x)
+        x = F.relu(x)
+
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        if self.spectrogram_upsampler:  # use conditional model
+            spectrogram = self.spectrogram_upsampler(spectrogram)
+
+        skip = None
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, diffusion_step, spectrogram)
+            skip = skip_connection if skip is None else skip_connection + skip
+
+        x = skip / sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)
+        return x
```
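The vendored file follows the original DiffWave architecture (Kong et al., 2020): a WaveNet-style denoiser of 30 dilated residual blocks, conditioned on a mel spectrogram that `SpectrogramUpsampler` stretches 16 * 16 = 256x in time to match `hop_samples`. A smoke-test sketch of the conditional forward pass; shapes follow from the defaults above, and the module path is inferred from the file list:

```python
import torch
from lt_tensor.model_zoo.diffwave.model import DiffWave, params

model = DiffWave(params)
batch, frames = 2, 62                     # crop_mel_frames from the defaults above
audio = torch.randn(batch, frames * 256)  # hop_samples=256 waveform samples per mel frame
spec = torch.randn(batch, 80, frames)     # [B, n_mels, frames]
t = torch.randint(0, len(params.noise_schedule), (batch,))

noise_pred = model(audio, t, spec)        # -> [B, 1, frames * 256]
print(noise_pred.shape)                   # torch.Size([2, 1, 15872])
```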