minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
s3tokenizer/model_v2.py
@@ -0,0 +1,605 @@
+ # Copyright (c) (Mddct: Dinghao Zhou)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from einops import rearrange
+
+ from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention
+ from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments
+
+
+ @dataclass
+ class ModelConfig:
+     n_mels: int = 128
+     n_audio_ctx: int = 1500
+     n_audio_state: int = 1280
+     n_audio_head: int = 20
+     n_audio_layer: int = 6
+     n_codebook_size: int = 3**8
+
+     use_sdpa: bool = False
+
+
+ def precompute_freqs_cis(dim: int,
+                          end: int,
+                          theta: float = 10000.0,
+                          scaling=None):
+     freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     if scaling is not None:
+         t = t * scaling
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+
+     return torch.cat((freqs_cis, freqs_cis), dim=-1)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     real = torch.view_as_real(freqs_cis)
+     cos, sin = real[:, :, 0], real[:, :, 1]
+     cos = cos.unsqueeze(0).unsqueeze(2)
+     sin = sin.unsqueeze(0).unsqueeze(2)
+
+     D = xq.shape[-1]
+     half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:]
+     xq_r = torch.cat((-half_r, half_l), dim=-1)
+
+     D = xk.shape[-1]
+
+     half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:]
+     xk_r = torch.cat((-half_r, half_l), dim=-1)
+
+     return xq * cos + xq_r * sin, xk * cos + xk_r * sin
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [
+         d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
+     ]
+     return freqs_cis.view(*shape)
+
+
+ class FSQCodebook(torch.nn.Module):
+
+     def __init__(self, dim: int, level: int = 3):
+         super().__init__()
+         self.project_down = torch.nn.Linear(dim, 8)
+         self.level = level
+         self.embed = None
+
+     @torch.inference_mode()
+     def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+         x = rearrange(x, "... d -> (...) d")
+         return x
+
+     @torch.inference_mode()
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         print("fixed s3 encode")
+         x_shape = x.shape
+         x = self.preprocess(x)  # -> (B, T, F)
+         h = self.project_down(x).float()  # -> (B, T, D)
+         # Map the continuous values to [0, L-1]
+         L = int(self.level)  # number of levels per dimension
+         eps = 1e-6
+         h = torch.tanh(h)
+         h = torch.clamp(h, -1 + eps, 1 - eps)
+         h = ((h + 1.0) * (L - 1) / 2.0).round().to(torch.int64)  # digits: (B, T, D) in [0..L-1]
+         # Pack into a single base-L index
+         D = h.size(-1)
+         powers = (L ** torch.arange(D, device=h.device, dtype=torch.int64))  # (D,)
+         idx = (h * powers.unsqueeze(0)).sum(dim=-1)
+         idx = idx.reshape(x_shape[0], x_shape[1]).int()  # -> (B, T)
+         return idx
+
+     @torch.inference_mode()
+     def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError(
+             'There is no official up project component provided')
+
+
+ class FSQVectorQuantization(torch.nn.Module):
+     """Vector quantization implementation (inference-only).
+     Args:
+         dim (int): Dimension
+         codebook_size (int): Codebook size
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         codebook_size: int,
+     ):
+         super().__init__()
+         assert 3**8 == codebook_size
+         self._codebook = FSQCodebook(dim=dim, level=3)
+         self.codebook_size = codebook_size
+
+     @property
+     def codebook(self):
+         return self._codebook.embed
+
+     @torch.inference_mode()
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         return self._codebook.encode(x)
+
+     @torch.inference_mode()
+     def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
+         quantize = self._codebook.decode(embed_ind)
+         quantize = rearrange(quantize, "b n d -> b d n")
+         return quantize
+
+
+ class FSMNMultiHeadAttention(MultiHeadAttention):
+
+     def __init__(
+         self,
+         n_state: int,
+         n_head: int,
+         kernel_size: int = 31,
+         use_sdpa: bool = False,
+     ):
+         super().__init__(n_state, n_head)
+
+         self.fsmn_block = torch.nn.Conv1d(n_state,
+                                           n_state,
+                                           kernel_size,
+                                           stride=1,
+                                           padding=0,
+                                           groups=n_state,
+                                           bias=False)
+         self.left_padding = (kernel_size - 1) // 2
+         self.right_padding = kernel_size - 1 - self.left_padding
+         self.pad_fn = torch.nn.ConstantPad1d(
+             (self.left_padding, self.right_padding), 0.0)
+
+         self.use_sdpa = use_sdpa
+
+     def forward_fsmn(self,
+                      inputs: torch.Tensor,
+                      mask: Optional[torch.Tensor] = None):
+         b, t, _, _ = inputs.size()
+         inputs = inputs.view(b, t, -1)
+         if mask is not None and mask.size(2) > 0:  # time2 > 0
+             inputs = inputs * mask
+         x = inputs.transpose(1, 2)
+         x = self.pad_fn(x)
+         x = self.fsmn_block(x)
+         x = x.transpose(1, 2)
+         x += inputs
+         return x * mask
+
+     def qkv_attention(self,
+                       q: torch.Tensor,
+                       k: torch.Tensor,
+                       v: torch.Tensor,
+                       mask: Optional[torch.Tensor] = None,
+                       mask_pad: Optional[torch.Tensor] = None,
+                       freqs_cis: Optional[torch.Tensor] = None):
+         _, _, D = q.shape
+         scale = (D // self.n_head)**-0.25
+         q = q.view(*q.shape[:2], self.n_head, -1)
+         k = k.view(*k.shape[:2], self.n_head, -1)
+         v = v.view(*v.shape[:2], self.n_head, -1)
+
+         if freqs_cis is not None:
+             q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
+
+         fsm_memory = self.forward_fsmn(v, mask_pad)
+
+         q = q.permute(0, 2, 1, 3) * scale
+         v = v.permute(0, 2, 1, 3)
+
+         if not self.use_sdpa:
+             k = k.permute(0, 2, 3, 1) * scale
+             qk = q @ k  # (B, n_head, T, T)
+             if mask is not None:
+                 qk = qk + mask
+             qk = qk.float()
+             w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
+             return (w @ v).permute(
+                 0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
+         else:
+             k = k.permute(0, 2, 1, 3) * scale
+             assert mask is not None
+             output = torch.nn.functional.scaled_dot_product_attention(
+                 q,
+                 k,
+                 v,
+                 attn_mask=mask,
+                 dropout_p=0.,
+                 scale=1.,
+             )
+             output = (output.transpose(1,
+                                        2).contiguous().view(q.size(0), -1, D)
+                       )  # (batch, time1, d_model)
+             return output, None, fsm_memory
+
+     def forward(self,
+                 x: torch.Tensor,
+                 mask: Optional[torch.Tensor] = None,
+                 mask_pad: Optional[torch.Tensor] = None,
+                 freqs_cis: Optional[torch.Tensor] = None):
+
+         q = self.query(x)
+         k = self.key(x)
+         v = self.value(x)
+
+         wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad,
+                                                 freqs_cis)
+         return self.out(wv) + fsm_memory, qk
+
+
+ class ResidualAttentionBlock(torch.nn.Module):
+
+     def __init__(
+         self,
+         n_state: int,
+         n_head: int,
+         kernel_size: int = 31,
+         use_sdpa: bool = False,
+     ):
+         super().__init__()
+
+         self.attn = FSMNMultiHeadAttention(n_state,
+                                            n_head,
+                                            kernel_size,
+                                            use_sdpa=use_sdpa)
+         self.attn_ln = LayerNorm(n_state, eps=1e-6)
+
+         n_mlp = n_state * 4
+
+         self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(),
+                                        Linear(n_mlp, n_state))
+         self.mlp_ln = LayerNorm(n_state)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         mask_pad: Optional[torch.Tensor] = None,
+         freqs_cis: Optional[torch.Tensor] = None,
+     ):
+         x = x + self.attn(
+             self.attn_ln(x), mask=mask, mask_pad=mask_pad,
+             freqs_cis=freqs_cis)[0]
+
+         x = x + self.mlp(self.mlp_ln(x))
+         return x
+
+
+ class AudioEncoderV2(torch.nn.Module):
+
+     def __init__(
+         self,
+         n_mels: int,
+         n_state: int,
+         n_head: int,
+         n_layer: int,
+         stride: int,
+         use_sdpa: bool,
+     ):
+         super().__init__()
+         self.stride = stride
+
+         self.conv1 = Conv1d(n_mels,
+                             n_state,
+                             kernel_size=3,
+                             stride=stride,
+                             padding=1)
+         self.conv2 = Conv1d(n_state,
+                             n_state,
+                             kernel_size=3,
+                             stride=2,
+                             padding=1)
+         self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
+         self.blocks = torch.nn.ModuleList([
+             ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
+             for _ in range(n_layer)
+         ])
+
+     def forward(self, x: torch.Tensor,
+                 x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         x : torch.Tensor, shape = (batch_size, n_mels, T)
+             the mel spectrogram of the audio
+         x_len: torch.Tensor, shape = (batch_size,)
+             length of each audio in x
+         """
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = torch.nn.functional.gelu(self.conv1(x * mask))
+         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = torch.nn.functional.gelu(self.conv2(x * mask))
+         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = x.permute(0, 2, 1)  # (B, T // 2, n_state)
+         freqs_cis = self.freqs_cis.to(x.device)
+         mask_pad = mask.transpose(1, 2)
+         mask = mask_to_bias(mask, x.dtype)
+
+         tmp = torch.view_as_real(freqs_cis)
+         cos, sin = tmp[:, :, 0], tmp[:, :, 1]
+
+         cos = torch.cat((cos, cos), dim=-1)
+         sin = torch.cat((sin, sin), dim=-1)
+         cos = cos.unsqueeze(0).unsqueeze(2)
+         sin = sin.unsqueeze(0).unsqueeze(2)
+
+         for block in self.blocks:
+             x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)])
+
+         return x, x_len
+
+
+ class S3TokenizerV2(torch.nn.Module):
+     """S3 tokenizer v2 implementation (inference-only).
+     Args:
+         config (ModelConfig): Config
+     """
+
+     def __init__(self, name: str, config: ModelConfig = ModelConfig()):
+         super().__init__()
+         self.name = name  # Store model name for token_rate determination
+         if 'v1' not in name:
+             assert 'v2' in name
+             # TODO(Mddct): make it configureable
+             config.n_codebook_size = 3**8
+         self.config = config
+         self.encoder = AudioEncoderV2(
+             self.config.n_mels,
+             self.config.n_audio_state,
+             self.config.n_audio_head,
+             self.config.n_audio_layer,
+             2,
+             self.config.use_sdpa,
+         )
+         self.quantizer = FSQVectorQuantization(
+             self.config.n_audio_state,
+             self.config.n_codebook_size,
+         )
+
+     def forward(self, mel: torch.Tensor,
+                 mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         return self.quantize(mel, mel_len)
+
+     @torch.inference_mode()
+     def quantize(self, mel: torch.Tensor,
+                  mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Quantize mel spectrogram to tokens, with automatic long audio handling.
+
+         Args:
+             mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+             mel_len: mel length tensor, shape (batch_size,)
+
+         Returns:
+             code: quantized tokens, shape (batch_size, T')
+             code_len: token length, shape (batch_size,)
+         """
+         # Check if any audio in the batch exceeds 30 seconds
+         # Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
+         max_frames = 3000
+
+         # Check which samples are long audio
+         long_audio_mask = mel_len > max_frames
+
+         if long_audio_mask.any():
+             # Has long audio - need special processing
+             return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
+                                               max_frames)
+         else:
+             # All short audio - use original method
+             hidden, code_len = self.encoder(mel, mel_len)
+             code = self.quantizer.encode(hidden)
+             return code, code_len
+
+     @torch.inference_mode()
+     def _quantize_mixed_batch(
+             self, mel: torch.Tensor, mel_len: torch.Tensor,
+             long_audio_mask: torch.Tensor,
+             max_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Handle mixed batch with both short and long audio using unified batch processing.
+
+         Args:
+             mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+             mel_len: mel length tensor, shape (batch_size,)
+             long_audio_mask: boolean mask for long audio, shape (batch_size,)
+             max_frames: maximum frames for short audio
+
+         Returns:
+             code: quantized tokens, shape (batch_size, T')
+             code_len: token length, shape (batch_size,)
+         """
+         batch_size = mel.size(0)
+
+         # Parameters for sliding window
+         sample_rate = 16000
+         hop_length = 160  # Default hop length for mel spectrogram
+         window_size = 30  # seconds
+         overlap = 4  # seconds
+
+         # Calculate frame-based parameters
+         frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
+         frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
+         frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames
+
+         # Collect all segments to process (including short and long audio segments)
+         all_segments = []
+         all_segments_len = []
+         segment_info = [
+         ]  # Record which audio each segment belongs to and whether it's long audio
+
+         # Process all audio in the batch
+         for batch_idx in range(batch_size):
+             audio_mel = mel[batch_idx]
+             audio_mel_len = mel_len[batch_idx]
+             is_long_audio = long_audio_mask[batch_idx].item()
+
+             if not is_long_audio:
+                 # Short audio: process directly as a single segment
+                 segment = audio_mel[:, :audio_mel_len]
+                 seg_len = audio_mel_len.item()
+
+                 # Pad to max_frames if necessary
+                 if seg_len < frames_per_window:
+                     pad_size = frames_per_window - seg_len
+                     segment = torch.nn.functional.pad(segment, (0, pad_size))
+
+                 all_segments.append(segment)
+                 all_segments_len.append(
+                     torch.tensor(seg_len, device=mel.device))
+                 segment_info.append({
+                     'batch_idx': batch_idx,
+                     'is_long_audio': False,
+                     'segment_idx': 0,
+                     'total_segments': 1
+                 })
+             else:
+                 # Long audio: split into multiple segments
+                 start = 0
+                 segment_idx = 0
+                 while start < audio_mel_len:
+                     end = min(start + frames_per_window, audio_mel_len)
+                     segment = audio_mel[:, start:end]
+
+                     seg_len = segment.size(1)
+                     # Pad if necessary
+                     if seg_len < frames_per_window:
+                         pad_size = frames_per_window - seg_len
+                         segment = torch.nn.functional.pad(
+                             segment, (0, pad_size))
+
+                     all_segments.append(segment)
+                     all_segments_len.append(
+                         torch.tensor(seg_len, device=mel.device))
+                     segment_info.append({
+                         'batch_idx': batch_idx,
+                         'is_long_audio': True,
+                         'segment_idx': segment_idx,
+                         'total_segments': None  # Will be filled later
+                     })
+
+                     segment_idx += 1
+                     start += frames_per_stride
+
+                 # Update total_segments info
+                 total_segments = segment_idx
+                 for info in segment_info:
+                     if info['batch_idx'] == batch_idx and info['is_long_audio']:
+                         info['total_segments'] = total_segments
+
+         if not all_segments:
+             # Fallback if no segments
+             return torch.zeros(batch_size,
+                                0,
+                                dtype=torch.long,
+                                device=mel.device), torch.zeros(
+                                    batch_size,
+                                    dtype=torch.long,
+                                    device=mel.device)
+
+         # Unified batch processing for all segments
+         unified_batch_mel = torch.stack(all_segments)
+         unified_batch_lens = torch.stack(all_segments_len)
+
+         # Process all segments at once
+         hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
+         codes = self.quantizer.encode(hidden)
+
+         # Reorganize results based on segment_info
+         results = {}  # batch_idx -> (code_tensor, code_len)
+
+         for seg_idx, info in enumerate(segment_info):
+             batch_idx = info['batch_idx']
+             is_long_audio = info['is_long_audio']
+             segment_idx = info['segment_idx']
+
+             # Get codes for current segment
+             segment_code = codes[
+                 seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
+
+             if not is_long_audio:
+                 # Short audio: use directly
+                 code_tensor = torch.tensor(segment_code,
+                                            dtype=torch.long,
+                                            device=mel.device)
+                 results[batch_idx] = (code_tensor, len(segment_code))
+             else:
+                 # Long audio: collect all segments
+                 if batch_idx not in results:
+                     results[batch_idx] = []
+                 results[batch_idx].append(segment_code)
+
+         # Process long audio segment merging
+         for batch_idx in range(batch_size):
+             if long_audio_mask[batch_idx].item():
+                 # Merge long audio segments
+                 audio_codes = results[batch_idx]
+
+                 # V2 models use 25Hz token rate
+                 token_rate = 25
+
+                 merged_codes = merge_tokenized_segments(audio_codes,
+                                                         overlap=overlap,
+                                                         token_rate=token_rate)
+
+                 # Convert to tensor
+                 merged_codes_tensor = torch.tensor(merged_codes,
+                                                    dtype=torch.long,
+                                                    device=mel.device)
+                 results[batch_idx] = (merged_codes_tensor, len(merged_codes))
+
+         # Construct final output
+         max_code_len = max(code_info[1] for code_info in results.values())
+
+         output_codes = torch.zeros(batch_size,
+                                    max_code_len,
+                                    dtype=torch.long,
+                                    device=mel.device)
+         output_codes_len = torch.zeros(batch_size,
+                                        dtype=torch.long,
+                                        device=mel.device)
+
+         for batch_idx, (code_tensor, code_len) in results.items():
+             output_codes[batch_idx, :code_len] = code_tensor
+             output_codes_len[batch_idx] = code_len
+
+         return output_codes, output_codes_len
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def init_from_onnx(self, onnx_path: str):
+         ckpt = onnx2torch(onnx_path, None, False)
+         self.load_state_dict(ckpt, strict=True)
+
+     def init_from_pt(self, ckpt_path: str):
+         ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
+         self.load_state_dict(ckpt, strict=True)
+
+     def freeze(self):
+         for _, param in self.named_parameters():
+             param.requires_grad = False
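
Taken as a whole, the hunk defines a complete inference-only tokenizer: mel features are downsampled 4x by two strided convolutions, passed through FSMN-augmented attention blocks with rotary embeddings, projected to 8 dimensions, and finite-scalar-quantized into packed base-3 indices (3**8 = 6561 possible tokens), with a sliding-window path for audio longer than 30 seconds. The sketch below is a minimal shape-level smoke test of that interface, under two assumptions that are not stated in the diff itself: that this hunk is s3tokenizer/model_v2.py (the only +605-line file in the listing) and that randomly initialized weights are acceptable for checking shapes. Real use would first load a checkpoint via init_from_onnx or init_from_pt and feed genuine mel features; the model name "v2" is a placeholder whose only role here is to satisfy the 'v2' substring check in __init__.

    import torch

    from s3tokenizer.model_v2 import ModelConfig, S3TokenizerV2  # module path assumed from the file listing

    config = ModelConfig()                   # 128 mel bins, 1280-d encoder state, 3**8-entry codebook
    tokenizer = S3TokenizerV2("v2", config)  # placeholder name; any name containing "v2" keeps the FSQ codebook
    tokenizer.freeze()                       # inference-only: disable gradients on all parameters

    # Two synthetic utterances, 6 s and 4.8 s at 100 mel frames per second (hop_length=160 @ 16 kHz),
    # both under the 3000-frame threshold, so quantize() takes the short-audio path.
    mel = torch.randn(2, config.n_mels, 600)
    mel_len = torch.tensor([600, 480])

    codes, codes_len = tokenizer.quantize(mel, mel_len)
    print(codes.shape, codes_len)            # torch.Size([2, 150]) and tensor([150, 120]): 4x temporal downsampling
    assert int(codes.max()) < 3**8           # each token packs 8 base-3 digits, so indices stay below 6561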