minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
s3tokenizer/model.py ADDED
@@ -0,0 +1,546 @@
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
#               2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modified from https://github.com/openai/whisper/blob/main/whisper/model.py
   Add EuclideanCodebook & VectorQuantization
"""

from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn

from .utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments


@dataclass
class ModelConfig:
    n_mels: int = 128
    n_audio_ctx: int = 1500
    n_audio_state: int = 1280
    n_audio_head: int = 20
    n_audio_layer: int = 6
    n_codebook_size: int = 4096

    use_sdpa: bool = False


class LayerNorm(nn.LayerNorm):

    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):

    def _conv_forward(self, x: Tensor, weight: Tensor,
                      bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment *
                               torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):

    def __init__(self, n_state: int, n_head: int, use_sdpa: bool = False):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.use_sdpa = use_sdpa

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self,
                      q: Tensor,
                      k: Tensor,
                      v: Tensor,
                      mask: Optional[Tensor] = None):
        _, _, D = q.shape
        scale = (D // self.n_head)**-0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if not self.use_sdpa:
            k = k.permute(0, 2, 3, 1) * scale
            qk = q @ k  # (B, n_head, T, T)
            if mask is not None:
                qk = qk + mask
            qk = qk.float()
            w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
            return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
        else:
            k = k.permute(0, 2, 1, 3) * scale
            assert mask is not None
            output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=0.,
                scale=1.,
            )
            output = (output.transpose(1, 2).contiguous().view(
                q.size(0), -1, D))  # (batch, time1, d_model)
            return output, None


class ResidualAttentionBlock(nn.Module):

    def __init__(self, n_state: int, n_head: int, use_sdpa: bool):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head, use_sdpa=use_sdpa)
        self.attn_ln = LayerNorm(n_state)

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(),
                                 Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        stride: int,
        use_sdpa: bool,
    ):
        super().__init__()
        self.stride = stride
        self.conv1 = Conv1d(n_mels,
                            n_state,
                            kernel_size=3,
                            stride=stride,
                            padding=1)
        self.conv2 = Conv1d(n_state,
                            n_state,
                            kernel_size=3,
                            stride=2,
                            padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
            ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
            for _ in range(n_layer)
        ])

    def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
        """
        x : torch.Tensor, shape = (batch_size, n_mels, T)
            the mel spectrogram of the audio
        x_len: torch.Tensor, shape = (batch_size,)
            length of each audio in x
        """
        mask = make_non_pad_mask(x_len).unsqueeze(1)
        x = F.gelu(self.conv1(x * mask))
        x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
        mask = make_non_pad_mask(x_len).unsqueeze(1)
        x = F.gelu(self.conv2(x * mask))
        x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
        mask = make_non_pad_mask(x_len).unsqueeze(1)
        x = x.permute(0, 2, 1)  # (B, T // 2, n_state)

        mask = mask_to_bias(mask, x.dtype)

        x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype)

        for block in self.blocks:
            x = block(x, mask.unsqueeze(1))

        return x, x_len


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance (inference-only).
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
    """

    def __init__(self, dim: int, codebook_size: int):
        super().__init__()
        embed = torch.zeros(codebook_size, dim)
        self.codebook_size = codebook_size
        self.register_buffer("embed", embed)

    @torch.inference_mode()
    def preprocess(self, x: Tensor) -> Tensor:
        x = rearrange(x, "... d -> (...) d")
        return x

    @torch.inference_mode()
    def quantize(self, x: Tensor) -> Tensor:
        embed = self.embed.t().to(x.dtype)
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
                 embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    @torch.inference_mode()
    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    @torch.inference_mode()
    def dequantize(self, embed_ind: Tensor) -> Tensor:
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    @torch.inference_mode()
    def encode(self, x: Tensor) -> Tensor:
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    @torch.inference_mode()
    def decode(self, embed_ind: Tensor) -> Tensor:
        quantize = self.dequantize(embed_ind)
        return quantize


class VectorQuantization(nn.Module):
    """Vector quantization implementation (inference-only).
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
    """

    def __init__(self, dim: int, codebook_size: int):
        super().__init__()
        self._codebook = EuclideanCodebook(dim=dim,
                                           codebook_size=codebook_size)
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    @torch.inference_mode()
    def encode(self, x: Tensor) -> Tensor:
        x = F.normalize(x.float(), p=2, dim=-1)
        embed_in = self._codebook.encode(x)
        return embed_in

    @torch.inference_mode()
    def decode(self, embed_ind: Tensor) -> Tensor:
        quantize = self._codebook.decode(embed_ind)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize


class S3Tokenizer(nn.Module):
    """S3 tokenizer implementation (inference-only).
    Args:
        name (str): Model name, used to select the encoder stride and token rate
        config (ModelConfig): Config
    """

    def __init__(self, name: str, config: ModelConfig = ModelConfig()):
        super().__init__()
        self.name = name  # Store model name for token_rate determination
        self.config = config
        self.encoder = AudioEncoder(
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            2 if name == "speech_tokenizer_v1_25hz" else 1,
            self.config.use_sdpa,
        )
        self.quantizer = VectorQuantization(self.config.n_audio_state,
                                            self.config.n_codebook_size)

    def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
        return self.quantize(mel, mel_len)

    @torch.inference_mode()
    def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Quantize mel spectrogram to tokens, with automatic long audio handling.

        Args:
            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
            mel_len: mel length tensor, shape (batch_size,)

        Returns:
            code: quantized tokens, shape (batch_size, T')
            code_len: token length, shape (batch_size,)
        """
        # Check if any audio in the batch exceeds 30 seconds
        # Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
        max_frames = 3000

        # Check which samples are long audio
        long_audio_mask = mel_len > max_frames

        if long_audio_mask.any():
            # Has long audio - need special processing
            return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
                                              max_frames)
        else:
            # All short audio - use original method
            hidden, code_len = self.encoder(mel, mel_len)
            code = self.quantizer.encode(hidden)
            return code, code_len

    @torch.inference_mode()
    def _quantize_mixed_batch(self, mel: Tensor, mel_len: Tensor,
                              long_audio_mask: Tensor,
                              max_frames: int) -> Tuple[Tensor, Tensor]:
        """
        Handle mixed batch with both short and long audio using unified batch processing.

        Args:
            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
            mel_len: mel length tensor, shape (batch_size,)
            long_audio_mask: boolean mask for long audio, shape (batch_size,)
            max_frames: maximum frames for short audio

        Returns:
            code: quantized tokens, shape (batch_size, T')
            code_len: token length, shape (batch_size,)
        """
        batch_size = mel.size(0)

        # Parameters for sliding window
        sample_rate = 16000
        hop_length = 160  # Default hop length for mel spectrogram
        window_size = 30  # seconds
        overlap = 4  # seconds

        # Calculate frame-based parameters
        frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
        frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
        frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames

        # Collect all segments to process (short audio and long-audio segments)
        all_segments = []
        all_segments_len = []
        # Record which audio each segment belongs to and whether it's long audio
        segment_info = []

        # Process all audio in the batch
        for batch_idx in range(batch_size):
            audio_mel = mel[batch_idx]
            audio_mel_len = mel_len[batch_idx]
            is_long_audio = long_audio_mask[batch_idx].item()

            if not is_long_audio:
                # Short audio: process directly as a single segment
                segment = audio_mel[:, :audio_mel_len]
                seg_len = audio_mel_len.item()

                # Pad to max_frames if necessary
                if seg_len < frames_per_window:
                    pad_size = frames_per_window - seg_len
                    segment = F.pad(segment, (0, pad_size))

                all_segments.append(segment)
                all_segments_len.append(
                    torch.tensor(seg_len, device=mel.device))
                segment_info.append({
                    'batch_idx': batch_idx,
                    'is_long_audio': False,
                    'segment_idx': 0,
                    'total_segments': 1
                })
            else:
                # Long audio: split into multiple segments
                start = 0
                segment_idx = 0
                while start < audio_mel_len:
                    end = min(start + frames_per_window, audio_mel_len)
                    segment = audio_mel[:, start:end]

                    seg_len = segment.size(1)
                    # Pad if necessary
                    if seg_len < frames_per_window:
                        pad_size = frames_per_window - seg_len
                        segment = F.pad(segment, (0, pad_size))

                    all_segments.append(segment)
                    all_segments_len.append(
                        torch.tensor(seg_len, device=mel.device))
                    segment_info.append({
                        'batch_idx': batch_idx,
                        'is_long_audio': True,
                        'segment_idx': segment_idx,
                        'total_segments': None  # Will be filled later
                    })

                    segment_idx += 1
                    start += frames_per_stride

                # Update total_segments info
                total_segments = segment_idx
                for info in segment_info:
                    if info['batch_idx'] == batch_idx and info['is_long_audio']:
                        info['total_segments'] = total_segments

        if not all_segments:
            # Fallback if no segments
            return torch.zeros(batch_size,
                               0,
                               dtype=torch.long,
                               device=mel.device), torch.zeros(
                                   batch_size,
                                   dtype=torch.long,
                                   device=mel.device)

        # Unified batch processing for all segments
        unified_batch_mel = torch.stack(all_segments)
        unified_batch_lens = torch.stack(all_segments_len)

        # Process all segments at once
        hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
        codes = self.quantizer.encode(hidden)

        # Reorganize results based on segment_info
        results = {}  # batch_idx -> (code_tensor, code_len)

        for seg_idx, info in enumerate(segment_info):
            batch_idx = info['batch_idx']
            is_long_audio = info['is_long_audio']
            segment_idx = info['segment_idx']

            # Get codes for current segment
            segment_code = codes[
                seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()

            if not is_long_audio:
                # Short audio: use directly
                code_tensor = torch.tensor(segment_code,
                                           dtype=torch.long,
                                           device=mel.device)
                results[batch_idx] = (code_tensor, len(segment_code))
            else:
                # Long audio: collect all segments
                if batch_idx not in results:
                    results[batch_idx] = []
                results[batch_idx].append(segment_code)

        # Process long audio segment merging
        for batch_idx in range(batch_size):
            if long_audio_mask[batch_idx].item():
                # Merge long audio segments
                audio_codes = results[batch_idx]

                # Determine token rate based on model name
                if hasattr(self, 'name') and self.name == "speech_tokenizer_v1":
                    token_rate = 50
                else:
                    token_rate = 25

                merged_codes = merge_tokenized_segments(audio_codes,
                                                        overlap=overlap,
                                                        token_rate=token_rate)

                # Convert to tensor
                merged_codes_tensor = torch.tensor(merged_codes,
                                                   dtype=torch.long,
                                                   device=mel.device)
                results[batch_idx] = (merged_codes_tensor, len(merged_codes))

        # Construct final output
        max_code_len = max(code_info[1] for code_info in results.values())

        output_codes = torch.zeros(batch_size,
                                   max_code_len,
                                   dtype=torch.long,
                                   device=mel.device)
        output_codes_len = torch.zeros(batch_size,
                                       dtype=torch.long,
                                       device=mel.device)

        for batch_idx, (code_tensor, code_len) in results.items():
            output_codes[batch_idx, :code_len] = code_tensor
            output_codes_len[batch_idx] = code_len

        return output_codes, output_codes_len

    @property
    def device(self):
        return next(self.parameters()).device

    def init_from_onnx(self, onnx_path: str):
        ckpt = onnx2torch(onnx_path, None, False)
        self.load_state_dict(ckpt, strict=True)

    def init_from_pt(self, ckpt_path: str):
        ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
        self.load_state_dict(ckpt, strict=True)

    def freeze(self):
        for _, param in self.named_parameters():
            param.requires_grad = False
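
For orientation, the sketch below shows how the pieces defined in this file fit together: S3Tokenizer wraps the Whisper-style AudioEncoder plus a VectorQuantization codebook, and quantize() maps padded mel spectrograms and their frame lengths to discrete token IDs (with the sliding-window path kicking in above 3000 frames). This is a minimal illustration only; the model name, the commented-out checkpoint path, and the dummy mel batch are assumptions, not values documented by the package, and real use would load weights (e.g. via init_from_onnx/init_from_pt or the helpers in s3tokenizer/__init__.py) and compute mels from 16 kHz audio.

    # Hypothetical usage sketch: drive S3Tokenizer.quantize directly with a dummy
    # 128-bin mel batch. Without loaded weights the token IDs are meaningless;
    # the snippet only demonstrates shapes and the call sequence.
    import torch
    from s3tokenizer.model import ModelConfig, S3Tokenizer

    model = S3Tokenizer("speech_tokenizer_v2_25hz", ModelConfig())  # assumed example name
    # model.init_from_onnx("speech_tokenizer_v2.onnx")  # hypothetical checkpoint path
    model.eval()

    mel = torch.randn(2, 128, 1500)                         # (batch, n_mels, T) padded mels
    mel_len = torch.tensor([1500, 980], dtype=torch.long)   # valid frames per item

    codes, codes_len = model.quantize(mel, mel_len)         # token IDs and per-item lengths
    print(codes.shape, codes_len)                           # codes: (batch, T'), IDs < 4096

Both mels here are under 3000 frames, so the short-audio path runs; a longer input would be split into 30 s windows with 4 s overlap and re-merged by merge_tokenized_segments.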