minicpmo_utils-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/hifigan/discriminator.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+try:
+    from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+except ImportError:
+    from torch.nn.utils import weight_norm, spectral_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+LRELU_SLOPE = 0.1
+
+
+class MultipleDiscriminator(nn.Module):
+    def __init__(
+        self, mpd: nn.Module, mrd: nn.Module
+    ):
+        super().__init__()
+        self.mpd = mpd
+        self.mrd = mrd
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(
+        self,
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+        num_embeddings: Optional[int] = None,
+    ):
+        """
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+        Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+        Args:
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                Defaults to None.
+        """
+
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+        )
+
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        num_embeddings: Optional[int] = None,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        if num_embeddings is not None:
+            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+            torch.nn.init.zeros_(self.emb.weight)
+
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
+        fmap = []
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
+        if cond_embedding_id is not None:
+            emb = self.emb(cond_embedding_id)
+            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+        else:
+            h = 0
+        x = self.conv_post(x)
+        fmap.append(x)
+        x += h
+
+        return x, fmap
+
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for _, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.abs(x_stft).transpose(2, 1)
+
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+        y = y.unsqueeze(1)
+        for _, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap
cosyvoice/hifigan/f0_predictor.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+try:
+    from torch.nn.utils.parametrizations import weight_norm
+except ImportError:
+    from torch.nn.utils import weight_norm
+
+
+class ConvRNNF0Predictor(nn.Module):
+    def __init__(self,
+                 num_class: int = 1,
+                 in_channels: int = 80,
+                 cond_channels: int = 512
+                 ):
+        super().__init__()
+
+        self.num_class = num_class
+        self.condnet = nn.Sequential(
+            weight_norm(
+                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
+            ),
+            nn.ELU(),
+            weight_norm(
+                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+            ),
+            nn.ELU(),
+            weight_norm(
+                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+            ),
+            nn.ELU(),
+            weight_norm(
+                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+            ),
+            nn.ELU(),
+            weight_norm(
+                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+            ),
+            nn.ELU(),
+        )
+        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.condnet(x)
+        x = x.transpose(1, 2)
+        return torch.abs(self.classifier(x).squeeze(-1))
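
Note: ConvRNNF0Predictor maps an 80-bin mel spectrogram of shape (batch, in_channels, frames) to a non-negative per-frame prediction of shape (batch, frames). A minimal sketch follows; it is illustrative only and not part of the package, and it assumes the file is importable as cosyvoice.hifigan.f0_predictor.

    import torch
    from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

    predictor = ConvRNNF0Predictor()   # num_class=1, in_channels=80, cond_channels=512
    mel = torch.randn(4, 80, 200)      # (batch, mel_bins, frames)
    f0 = predictor(mel)                # torch.abs keeps the output non-negative
    print(f0.shape)                    # torch.Size([4, 200])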