minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/token2wav.py ADDED
@@ -0,0 +1,248 @@
+ import io
+ import sys
+ import types
+
+ import torch
+ import torchaudio
+ import s3tokenizer
+ import onnxruntime
+ import numpy as np
+ from copy import deepcopy
+
+ import torchaudio.compliance.kaldi as kaldi
+ from stepaudio2.flashcosyvoice.modules.hifigan import HiFTGenerator
+ from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram
+ from hyperpyyaml import load_hyperpyyaml
+
+
+ def _setup_cosyvoice2_alias():
+     """Provide cosyvoice2.* compatibility aliases for hyperpyyaml without modifying flow.yaml."""
+     if 'cosyvoice2.flow.flow' in sys.modules:
+         # already registered; reuse it
+         return
+
+     # import the real implementations from stepaudio2
+     import stepaudio2.cosyvoice2.flow.flow as _step_flow
+     import stepaudio2.cosyvoice2.flow.flow_matching as _step_flow_matching
+     import stepaudio2.cosyvoice2.flow.decoder_dit as _step_decoder_dit
+     import stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 as _step_upsample
+
+     # create the top-level cosyvoice2 package and its subpackages
+     cosyvoice2_pkg = types.ModuleType('cosyvoice2')
+     cosyvoice2_flow_pkg = types.ModuleType('cosyvoice2.flow')
+     cosyvoice2_transformer_pkg = types.ModuleType('cosyvoice2.transformer')
+
+     # attach the submodules
+     cosyvoice2_flow_pkg.flow = _step_flow
+     cosyvoice2_flow_pkg.flow_matching = _step_flow_matching
+     cosyvoice2_flow_pkg.decoder_dit = _step_decoder_dit
+
+     cosyvoice2_transformer_pkg.upsample_encoder_v2 = _step_upsample
+
+     cosyvoice2_pkg.flow = cosyvoice2_flow_pkg
+     cosyvoice2_pkg.transformer = cosyvoice2_transformer_pkg
+
+     # register in sys.modules so dotted paths like `cosyvoice2.flow.flow.*` can be imported
+     sys.modules['cosyvoice2'] = cosyvoice2_pkg
+     sys.modules['cosyvoice2.flow'] = cosyvoice2_flow_pkg
+     sys.modules['cosyvoice2.flow.flow'] = _step_flow
+     sys.modules['cosyvoice2.flow.flow_matching'] = _step_flow_matching
+     sys.modules['cosyvoice2.flow.decoder_dit'] = _step_decoder_dit
+     sys.modules['cosyvoice2.transformer'] = cosyvoice2_transformer_pkg
+     sys.modules['cosyvoice2.transformer.upsample_encoder_v2'] = _step_upsample
+
+ def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
+     """Crossfade two tensors: fade the new segment in while the cached segment fades out."""
+     mel_overlap_len = int(window.shape[0] / 2)
+     fade_in_mel = fade_in_mel.clone()
+     fade_in_mel[..., :mel_overlap_len] = \
+         fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+     return fade_in_mel
+
+
+ class Token2wav():
+     def __init__(self, model_path, float16=False, n_timesteps=10):
+         self.float16 = float16
+         self.n_timesteps = n_timesteps
+
+         # register the cosyvoice2 aliases before loading flow.yaml
+         _setup_cosyvoice2_alias()
+
+         self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval()
+
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
+
+         with open(f"{model_path}/flow.yaml", "r") as f:
+             configs = load_hyperpyyaml(f)
+         self.flow = configs['flow']
+         if float16:
+             self.flow.half()
+         self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+         self.flow.cuda().eval()
+
+         self.hift = HiFTGenerator()
+         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
+         self.hift.load_state_dict(hift_state_dict, strict=True)
+         self.hift.cuda().eval()
+
+         self.cache = None
+         self.stream_cache = None  # populated by set_stream_cache()
+
+         # stream conf
+         self.mel_cache_len = 8  # hard-coded, 160 ms
+         self.source_cache_len = int(self.mel_cache_len * 480)  # 50 Hz mel -> 24 kHz wave
+         self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
+
+         # hifigan cache
+         self.hift_cache_dict = {}
+
+     def _prepare_prompt(self, prompt_wav):
+         audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
+         # TODO: pad `audio` with 3 tokens' worth (0.12 s) of trailing audio
+         mels = s3tokenizer.log_mel_spectrogram(audio)
+         mels, mels_lens = s3tokenizer.padding([mels])
+         prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())
+         # TODO: manually set the last 3 tokens to 4218, the silence token
+
+         spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+         spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+         spk_emb = torch.tensor(self.spk_model.run(
+             None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+         )[0], device='cuda')
+
+         audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
+         audio = audio.mean(dim=0, keepdim=True)  # [1, T]
+         if sample_rate != 24000:
+             audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+         prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+         prompt_mels = prompt_mel.unsqueeze(0).cuda()
+         prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')
+         prompt_mels = torch.nn.functional.pad(prompt_mels, (0, 0, 0, prompt_speech_tokens.shape[1] * self.flow.up_rate - prompt_mels.shape[1]), mode='replicate')
+         return prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens
+
+     def __call__(self, generated_speech_tokens, prompt_wav):
+         if self.cache is None:
+             self.cache = self._prepare_prompt(prompt_wav)
+         prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache
+
+         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+         generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')
+
+         with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
+             mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
+                                       prompt_speech_tokens, prompt_speech_tokens_lens,
+                                       prompt_mels, prompt_mels_lens, spk_emb, self.n_timesteps)
+
+         wav, _ = self.hift(speech_feat=mel)
+         output = io.BytesIO()
+         torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')
+
+         return output.getvalue()
+
+     def set_stream_cache(self, prompt_wav):
+         self.cache = self._prepare_prompt(prompt_wav)
+         prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache
+
+         print("prompt_speech_tokens_lens", prompt_speech_tokens_lens)
+         print("prompt_mels.shape", prompt_mels.shape)
+         print("prompt_mels_lens", prompt_mels_lens)
+
+         # right-pad the prompt with 3 silence tokens (id 4218)
+         right_pad_speech_tokens = torch.ones(1, 3, device=prompt_speech_tokens.device, dtype=prompt_speech_tokens.dtype) * 4218
+
+         # self.stream_cache = self.flow.setup_cache(
+         #     torch.cat([prompt_speech_tokens, prompt_speech_tokens[:, :3]], dim=1),
+         #     prompt_mels, spk_emb, n_timesteps=self.n_timesteps)
+
+         stream_cache = self.flow.setup_cache(
+             torch.cat([prompt_speech_tokens, right_pad_speech_tokens], dim=1),
+             prompt_mels, spk_emb, n_timesteps=self.n_timesteps)
+
+         # hift cache
+         hift_cache_dict = dict(
+             mel=torch.zeros(1, prompt_mels.shape[2], 0, device='cuda'),
+             source=torch.zeros(1, 1, 0, device='cuda'),
+             speech=torch.zeros(1, 0, device='cuda'),
+         )
+
+         # store the caches on the instance so stream() can use them,
+         # and also return them for callers that manage caches themselves
+         self.stream_cache = stream_cache
+         self.hift_cache_dict = hift_cache_dict
+         return stream_cache, hift_cache_dict
+
+
+     def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False, return_waveform=False):
+         if self.cache is None:
+             self.cache = self._prepare_prompt(prompt_wav)
+         prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache
+
+         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+         generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')
+
+         if self.stream_cache is None:
+             raise ValueError("stream_cache is not set")
+
+         with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
+             chunk_mel, self.stream_cache = self.flow.inference_chunk(
+                 token=generated_speech_tokens,
+                 spk=spk_emb,
+                 cache=self.stream_cache,
+                 last_chunk=last_chunk,
+                 n_timesteps=self.n_timesteps,
+             )
+         if self.stream_cache['estimator_att_cache'].shape[4] > (prompt_mels.shape[1] + 100):
+             self.stream_cache['estimator_att_cache'] = torch.cat([
+                 self.stream_cache['estimator_att_cache'][:, :, :, :, :prompt_mels.shape[1]],
+                 self.stream_cache['estimator_att_cache'][:, :, :, :, -100:],
+             ], dim=4)
+
+         # bug fix - 20260107
+         # likewise truncate conformer_att_cache, so unbounded growth cannot desync the position embeddings
+         if self.stream_cache['conformer_att_cache'].shape[3] > (prompt_mels.shape[1] + 100):
+             self.stream_cache['conformer_att_cache'] = torch.cat([
+                 self.stream_cache['conformer_att_cache'][:, :, :, :prompt_mels.shape[1], :],
+                 self.stream_cache['conformer_att_cache'][:, :, :, -100:, :],
+             ], dim=3)
+
+         # vocoder cache
+         hift_cache_mel = self.hift_cache_dict['mel']
+         hift_cache_source = self.hift_cache_dict['source']
+         hift_cache_speech = self.hift_cache_dict['speech']
+         mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)
+
+         speech, source = self.hift(mel, hift_cache_source)
+
+         # overlap speech smooth
+         if hift_cache_speech.shape[-1] > 0:
+             speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
+
+         # detect whether this is the first chunk (no valid speech cache yet)
+         is_first_chunk = hift_cache_speech.shape[-1] == 0
+
+         # update vocoder cache
+         self.hift_cache_dict = dict(
+             mel=mel[..., -self.mel_cache_len:].clone().detach(),
+             source=source[:, :, -self.source_cache_len:].clone().detach(),
+             speech=speech[:, -self.source_cache_len:].clone().detach(),
+         )
+
+         if not last_chunk:
+             if is_first_chunk:
+                 # first chunk: trim the tail and prepend silence to keep the length;
+                 # no fade-in is applied, so the original audio content is unchanged
+                 silence_padding = torch.zeros(1, self.source_cache_len, device=speech.device)
+                 speech = torch.cat([silence_padding, speech[:, :-self.source_cache_len]], dim=1)
+             else:
+                 # later chunks: trim the tail as usual (the next chunk's fade-in compensates)
+                 speech = speech[:, :-self.source_cache_len]
+
+         wav_np = speech.cpu().numpy()
+         if return_waveform:
+             return wav_np
+
+         # Clip to [-1, 1] to avoid overflow, then scale to int16
+         wav_np = np.clip(wav_np, -1.0, 1.0)
+         wav_int16 = (wav_np * 32767.0).astype('<i2')  # 16-bit little-endian PCM
+         pcm_bytes = wav_int16.tobytes()
+         return pcm_bytes
+
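For orientation, a minimal usage sketch of Token2wav follows (not part of the wheel). The model-directory filenames match those loaded in __init__ above; the directory path, prompt path, and token ids are placeholders, and the streaming loop assumes set_stream_cache() stores the caches on the instance as in the code above.

from stepaudio2.token2wav import Token2wav

t2w = Token2wav("/path/to/model_dir", float16=True, n_timesteps=10)

# One-shot synthesis: returns a complete WAV file as bytes.
wav_bytes = t2w([312, 87, 1024], "prompt.wav")  # hypothetical token ids
with open("out.wav", "wb") as f:
    f.write(wav_bytes)

# Streaming synthesis: prime the caches once, then feed token chunks;
# each call returns 16-bit little-endian PCM at 24 kHz.
t2w.set_stream_cache("prompt.wav")
token_chunks = [[312, 87], [1024, 55]]  # hypothetical chunked LLM output
pcm = b"".join(
    t2w.stream(chunk, "prompt.wav", last_chunk=(i == len(token_chunks) - 1))
    for i, chunk in enumerate(token_chunks)
)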
stepaudio2/utils.py ADDED
@@ -0,0 +1,91 @@
+ import librosa
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pad_sequence
+ import torchaudio
+ from typing import List
+
+
+ def _mel_filters(n_mels: int) -> torch.Tensor:
+     """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
+     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+     if n_mels == 128:
+         return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
+     else:
+         return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
+
+ def load_audio(file_path, target_rate=16000, max_length=None):
+     """
+     Open an audio file and read it as a mono waveform, resampling as necessary.
+     If max_length is provided, truncate the audio to that length.
+     """
+     waveform, sample_rate = torchaudio.load(file_path)
+     if sample_rate != target_rate:
+         waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
+     audio = waveform[0]  # get the first channel
+
+     # Truncate audio if it exceeds max_length
+     if max_length is not None and audio.shape[0] > max_length:
+         audio = audio[:max_length]
+
+     return audio
+
+ def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
+     """
+     Compute the log-Mel spectrogram with specific padding for StepAudio.
+     """
+     if not torch.is_tensor(audio):
+         if isinstance(audio, str):
+             audio = load_audio(audio)  # already returns a tensor
+         else:
+             audio = torch.from_numpy(audio)
+     if device is not None:
+         audio = audio.to(device)
+     if padding > 0:
+         audio = F.pad(audio, (0, padding))
+     window = torch.hann_window(400).to(audio.device)
+     stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
+     magnitudes = stft[..., :-1].abs() ** 2
+     filters = _mel_filters(n_mels).to(audio.device)
+     mel_spec = filters @ magnitudes
+
+     log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+     log_spec = (log_spec + 4.0) / 4.0
+     return log_spec
+
+ def compute_token_num(max_feature_len):
+     # First, audio goes through the encoder:
+     # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
+     # 2. conv2: kernel=3, stride=2, padding=1 -> size / 2
+     # 3. avg_pooler: kernel=2, stride=2 -> size / 2
+     max_feature_len = max_feature_len - 2  # remove padding
+     encoder_output_dim = (max_feature_len + 1) // 2 // 2  # after conv2 and avg_pooler
+
+     # Then through the adaptor (parameters from the config file):
+     padding = 1
+     kernel_size = 3  # from config: audio_encoder_config.kernel_size
+     stride = 2  # from config: audio_encoder_config.adapter_stride
+     adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
+     return adapter_output_dim
+
+ def padding_mels(data: List[torch.Tensor]):
+     """Pad a list of mel features into a batch.
+
+     Parameters
+     ----------
+     data: List[Tensor], each of shape (128, T)
+
+     Returns
+     -------
+     feats, feats_lengths
+     """
+     sample = data
+     assert isinstance(sample, list)
+     feats_lengths = torch.tensor([s.size(1) - 2 for s in sample],
+                                  dtype=torch.int32)
+     feats = [s.t() for s in sample]
+     padded_feats = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0)
+
+     return padded_feats.transpose(1, 2), feats_lengths
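To make the feature pipeline and the token arithmetic above concrete, a small sketch (not part of the wheel); the wav path is a placeholder, and the frame count assumes the default 479-sample padding with hop 160.

from stepaudio2.utils import load_audio, log_mel_spectrogram, padding_mels, compute_token_num

audio = load_audio("example.wav")        # mono 16 kHz tensor of shape [T]
mel = log_mel_spectrogram(audio)         # [128, T'] with T' = (T + 479) // 160
feats, feats_lens = padding_mels([mel])  # feats: [1, 128, T'], lengths = T' - 2

# e.g. a 3 s clip: T = 48000 -> T' = (48000 + 479) // 160 = 302 frames
# encoder: (302 - 2 + 1) // 2 // 2 = 75; adaptor: (75 + 2*1 - 3) // 2 + 1 = 38 tokens
num_tokens = compute_token_num(mel.shape[-1])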