minicpmo-utils 0.1.0 (minicpmo_utils-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
matcha/hifigan/meldataset.py
@@ -0,0 +1,217 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import math
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from librosa.util import normalize
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window  # pylint: disable=global-statement
+     if str(fmax) + "_" + str(y.device) not in mel_basis:  # test the same fmax+device key that is stored below
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args, required by librosa >= 0.10
+         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[str(y.device)],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
+
+
+ def get_dataset_filelist(a):
+     with open(a.input_training_file, encoding="utf-8") as fi:
+         training_files = [
+             os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+         ]
+
+     with open(a.input_validation_file, encoding="utf-8") as fi:
+         validation_files = [
+             os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+         ]
+     return training_files, validation_files
+
+
+ class MelDataset(torch.utils.data.Dataset):
+     def __init__(
+         self,
+         training_files,
+         segment_size,
+         n_fft,
+         num_mels,
+         hop_size,
+         win_size,
+         sampling_rate,
+         fmin,
+         fmax,
+         split=True,
+         shuffle=True,
+         n_cache_reuse=1,
+         device=None,
+         fmax_loss=None,
+         fine_tuning=False,
+         base_mels_path=None,
+     ):
+         self.audio_files = training_files
+         random.seed(1234)
+         if shuffle:
+             random.shuffle(self.audio_files)
+         self.segment_size = segment_size
+         self.sampling_rate = sampling_rate
+         self.split = split
+         self.n_fft = n_fft
+         self.num_mels = num_mels
+         self.hop_size = hop_size
+         self.win_size = win_size
+         self.fmin = fmin
+         self.fmax = fmax
+         self.fmax_loss = fmax_loss
+         self.cached_wav = None
+         self.n_cache_reuse = n_cache_reuse
+         self._cache_ref_count = 0
+         self.device = device
+         self.fine_tuning = fine_tuning
+         self.base_mels_path = base_mels_path
+
+     def __getitem__(self, index):
+         filename = self.audio_files[index]
+         if self._cache_ref_count == 0:
+             audio, sampling_rate = load_wav(filename)
+             audio = audio / MAX_WAV_VALUE
+             if not self.fine_tuning:
+                 audio = normalize(audio) * 0.95
+             self.cached_wav = audio
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+             self._cache_ref_count = self.n_cache_reuse
+         else:
+             audio = self.cached_wav
+             self._cache_ref_count -= 1
+
+         audio = torch.FloatTensor(audio)
+         audio = audio.unsqueeze(0)
+
+         if not self.fine_tuning:
+             if self.split:
+                 if audio.size(1) >= self.segment_size:
+                     max_audio_start = audio.size(1) - self.segment_size
+                     audio_start = random.randint(0, max_audio_start)
+                     audio = audio[:, audio_start : audio_start + self.segment_size]
+                 else:
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+
+             mel = mel_spectrogram(
+                 audio,
+                 self.n_fft,
+                 self.num_mels,
+                 self.sampling_rate,
+                 self.hop_size,
+                 self.win_size,
+                 self.fmin,
+                 self.fmax,
+                 center=False,
+             )
+         else:
+             mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+             mel = torch.from_numpy(mel)
+
+             if len(mel.shape) < 3:
+                 mel = mel.unsqueeze(0)
+
+             if self.split:
+                 frames_per_seg = math.ceil(self.segment_size / self.hop_size)
+
+                 if audio.size(1) >= self.segment_size:
+                     mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
+                     mel = mel[:, :, mel_start : mel_start + frames_per_seg]
+                     audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                 else:
+                     mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+
+         mel_loss = mel_spectrogram(
+             audio,
+             self.n_fft,
+             self.num_mels,
+             self.sampling_rate,
+             self.hop_size,
+             self.win_size,
+             self.fmin,
+             self.fmax_loss,
+             center=False,
+         )
+
+         return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
+
+     def __len__(self):
+         return len(self.audio_files)
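For orientation, a minimal usage sketch of the dataset above. The file paths, the 22050 Hz rate, and the STFT/mel settings are the common HiFi-GAN v1 defaults, assumed here for illustration rather than read from this package:

# Hypothetical usage sketch; paths and hyperparameters are assumptions.
import torch

from matcha.hifigan.meldataset import MelDataset

wav_files = ["data/wavs/utt0001.wav", "data/wavs/utt0002.wav"]  # assumed 22.05 kHz, 16-bit PCM

dataset = MelDataset(
    wav_files,
    segment_size=8192,   # samples per random training crop
    n_fft=1024,
    num_mels=80,
    hop_size=256,
    win_size=1024,
    sampling_rate=22050,
    fmin=0,
    fmax=8000,           # band-limited mel fed to the generator
    fmax_loss=None,      # full-band mel used as the loss target
)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

mel, audio, filename, mel_loss = dataset[0]
# mel: (num_mels, frames); audio: (segment_size,); mel_loss: full-band counterpart of mel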
matcha/hifigan/models.py
@@ -0,0 +1,368 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+ from .xutils import get_padding, init_weights
+
+ LRELU_SLOPE = 0.1
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super().__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[2],
+                         padding=get_padding(kernel_size, dilation[2]),
+                     )
+                 ),
+             ]
+         )
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+             ]
+         )
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+         super().__init__()
+         self.h = h
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+             ]
+         )
+         self.convs.apply(init_weights)
+
+     def forward(self, x):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class Generator(torch.nn.Module):
+     def __init__(self, h):
+         super().__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+         resblock = ResBlock1 if h.resblock == "1" else ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(
+                 weight_norm(
+                     ConvTranspose1d(
+                         h.upsample_initial_channel // (2**i),
+                         h.upsample_initial_channel // (2 ** (i + 1)),
+                         k,
+                         u,
+                         padding=(k - u) // 2,
+                     )
+                 )
+             )
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x):
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print("Removing weight norm...")
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super().__init__()
+         self.period = period
+         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                 norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                 norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                 norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+             ]
+         )
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [
+                 DiscriminatorP(2),
+                 DiscriminatorP(3),
+                 DiscriminatorP(5),
+                 DiscriminatorP(7),
+                 DiscriminatorP(11),
+             ]
+         )
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for _, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super().__init__()
+         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+                 norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                 norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                 norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                 norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+             ]
+         )
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [
+                 DiscriminatorS(use_spectral_norm=True),
+                 DiscriminatorS(),
+                 DiscriminatorS(),
+             ]
+         )
+         self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             if i != 0:
+                 y = self.meanpools[i - 1](y)
+                 y_hat = self.meanpools[i - 1](y_hat)
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss * 2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     r_losses = []
+     g_losses = []
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg**2)
+         loss += r_loss + g_loss
+         r_losses.append(r_loss.item())
+         g_losses.append(g_loss.item())
+
+     return loss, r_losses, g_losses
+
+
+ def generator_loss(disc_outputs):
+     loss = 0
+     gen_losses = []
+     for dg in disc_outputs:
+         l = torch.mean((1 - dg) ** 2)
+         gen_losses.append(l)
+         loss += l
+
+     return loss, gen_losses
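A matching inference sketch for the Generator defined above. The SimpleNamespace stands in for the usual AttrDict config; the values are the standard HiFi-GAN v1 hyperparameters, assumed here for illustration rather than read from this package:

# Hypothetical inference sketch; config values are assumptions (HiFi-GAN v1 defaults).
from types import SimpleNamespace

import torch

from matcha.hifigan.models import Generator

h = SimpleNamespace(
    resblock="1",                          # selects ResBlock1
    upsample_rates=[8, 8, 2, 2],           # product (256) must equal the mel hop_size
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)

gen = Generator(h).eval()
gen.remove_weight_norm()                   # fuse weight norm before inference

mel = torch.randn(1, 80, 100)              # (batch, num_mels, frames)
with torch.no_grad():
    wav = gen(mel)                         # (1, 1, 100 * 256), values in [-1, 1]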
matcha/hifigan/xutils.py
@@ -0,0 +1,60 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import glob
+ import os
+
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print(f"Loading '{filepath}'")
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print(f"Saving checkpoint to {filepath}")
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix):
+     pattern = os.path.join(cp_dir, prefix + "????????")
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]
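Finally, a small sketch of the checkpoint helpers in use. The directory name, the "g_" prefix, and the eight-character step suffix (which scan_checkpoint globs for with "????????") are assumptions:

# Hypothetical checkpoint round trip; directory and prefix are assumptions.
import os

import torch

from matcha.hifigan.xutils import load_checkpoint, save_checkpoint, scan_checkpoint

os.makedirs("ckpts", exist_ok=True)
save_checkpoint("ckpts/g_00001000", {"step": 1000, "weights": {"w": torch.zeros(1)}})

latest = scan_checkpoint("ckpts", "g_")   # newest path matching "g_" + 8 characters
if latest is not None:
    state = load_checkpoint(latest, "cpu")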