minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
s3tokenizer/utils.py ADDED
@@ -0,0 +1,390 @@
+ # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
+ #               2024 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
+ Add rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias()
+ Copy merge_tokenized_segments() from https://github.com/Mddct/s3tokenizer-long/blob/main/example.py
+ """
+
+ import os
+ from functools import lru_cache
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import onnx
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ def _rename_weights(weights_dict: dict):
+     """
+     Rename onnx weights to pytorch format.
+
+     Parameters
+     ----------
+     weights_dict: dict
+         The dict containing weights in onnx format
+
+     Returns
+     -------
+     A new weight dict containing the weights in pytorch format.
+     """
+     new_weight_dict = {}
+     for k in weights_dict.keys():
+         if "quantizer" in k:  # vq or fsq
+             if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1":
+                 new_weight_dict["quantizer._codebook.embed"] = weights_dict[k]
+         elif 'project_down' in k:  # v2
+             new_weight_dict[k] = weights_dict[k]
+         elif "positional_embedding" in k:  # positional emb
+             new_weight_dict[k] = weights_dict[k]
+         elif "conv" in k:  # 1/2 or 1/4 subsample
+             new_weight_dict[k] = weights_dict[k]
+         else:  # transformer blocks
+             assert "blocks" in k
+             new_k = (k[1:].replace('/', '.').replace(
+                 'MatMul', 'weight').replace('Add_1', 'bias').replace(
+                     'Mul', 'weight').replace('Add', 'bias').replace(
+                         'mlp.mlp', 'mlp')).replace('fsmn_block.Conv',
+                                                    'fsmn_block.weight')
+
+             new_weight_dict[f"encoder.{new_k}"] = weights_dict[k]
+     return new_weight_dict
+
+
+ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False):
+     """
+     Open an onnx file and convert to pytorch format.
+
+     Parameters
+     ----------
+     onnx_path: str
+         The onnx file to open, typically `speech_tokenizer_v1.onnx`
+
+     torch_path: str
+         The path to save the torch-formatted checkpoint.
+
+     verbose: bool
+         Logging info or not.
+
+     Returns
+     -------
+     A checkpoint dict containing the weights and their names, if torch_path is
+     None. Otherwise save checkpoint dict to the desired path.
+     """
+     onnx_model = onnx.load(onnx_path)
+     weights_dict = {}
+     initializer_map = {
+         initializer.name: initializer
+         for initializer in onnx_model.graph.initializer
+     }
+     for node in onnx_model.graph.node:
+         for input_name in node.input:
+             if input_name in initializer_map:
+                 ln_bias_name, ln_weight_name = None, None  # for v2 ln
+                 initializer = initializer_map[input_name]
+                 if input_name in [
+                         "onnx::Conv_1519",
+                         "encoders.conv1.weight",
+                         "onnx::Conv_2216",
+                 ]:  # v1_50hz, v1_25hz, v2_25hz
+                     weight_name = "encoder.conv1.weight"
+                 elif input_name in [
+                         "onnx::Conv_1520",
+                         "encoders.conv1.bias",
+                         "onnx::Conv_2217",
+                 ]:  # v1_50hz, v1_25hz, v2_25hz
+                     weight_name = "encoder.conv1.bias"
+                 elif input_name in [
+                         "onnx::Conv_1521",
+                         "encoders.conv2.weight",
+                         "onnx::Conv_2218",
+                 ]:
+                     weight_name = "encoder.conv2.weight"
+                 elif input_name in [
+                         "onnx::Conv_1522",
+                         "encoders.conv2.bias",
+                         "onnx::Conv_2219",
+                 ]:
+                     weight_name = "encoder.conv2.bias"
+                 elif input_name == "encoders.positional_embedding":
+                     weight_name = "encoder.positional_embedding"
+                 elif input_name == 'quantizer.project_in.bias':
+                     weight_name = "quantizer._codebook.project_down.bias"
+                 elif input_name == 'onnx::MatMul_2536':
+                     weight_name = "quantizer._codebook.project_down.weight"
+                 else:
+                     if node.op_type == 'LayerNormalization':  # in input_name:
+                         ln_name = node.name.replace('/LayerNormalization', '')
+                         ln_weight_name = ln_name + '.weight'
+                         ln_bias_name = ln_name + '.bias'
+                     else:
+                         weight_name = node.name
+                 if ln_weight_name is not None and ln_bias_name is not None:
+                     ln_inputs = node.input
+                     scale_name = ln_inputs[1]
+                     bias_name = ln_inputs[2]
+                     scale = onnx.numpy_helper.to_array(
+                         initializer_map[scale_name]).copy(
+                         ) if scale_name in initializer_map else None
+                     bias = onnx.numpy_helper.to_array(
+                         initializer_map[bias_name]).copy(
+                         ) if bias_name in initializer_map else None
+                     scale.flags.writeable = True
+                     bias.flags.writeable = True
+                     weight_tensor = torch.from_numpy(scale)
+                     bias_tensor = torch.from_numpy(bias)
+
+                     weights_dict[ln_bias_name] = bias_tensor
+                     weights_dict[ln_weight_name] = weight_tensor
+                 else:
+                     weight_array = onnx.numpy_helper.to_array(
+                         initializer).copy()
+                     weight_array.flags.writeable = True
+                     weight_tensor = torch.from_numpy(weight_array)
+                     if len(weight_tensor.shape) > 2 or weight_name in [
+                             "encoder.positional_embedding"
+                     ]:
+                         weights_dict[weight_name] = weight_tensor
+                     else:
+                         weights_dict[weight_name] = weight_tensor.t()
+
+     new_weights_dict = _rename_weights(weights_dict)
+     if verbose:
+         for k, v in new_weights_dict.items():
+             print(f"{k} : {v.shape} {v.dtype}")
+         print(f"PyTorch weights saved to {torch_path}")
+     del weights_dict, onnx_model
+     if torch_path:
+         torch.save(new_weights_dict, torch_path)
+     else:
+         return new_weights_dict
+
+
+ def load_audio(file: str, sr: int = 16000):
+     """
+     Open an audio file and read as mono waveform, resampling as necessary
+
+     Parameters
+     ----------
+     file: str
+         The audio file to open
+
+     sr: int
+         The sample rate to resample the audio if necessary
+
+     Returns
+     -------
+     A torch.Tensor containing the audio waveform, in float32 dtype.
+     """
+     audio, sample_rate = torchaudio.load(file)
+     if sample_rate != sr:
+         audio = torchaudio.transforms.Resample(sample_rate, sr)(audio)
+     audio = audio[0]  # get the first channel
+     return audio
+
+
+ @lru_cache(maxsize=None)
+ def _mel_filters(device, n_mels: int) -> torch.Tensor:
+     """
+     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+     Allows decoupling librosa dependency; saved using:
+
+         np.savez_compressed(
+             "mel_filters.npz",
+             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+             mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+         )
+     """
+     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+     filters_path = os.path.join(os.path.dirname(__file__), "assets",
+                                 "mel_filters.npz")
+     with np.load(filters_path, allow_pickle=False) as f:
+         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+ def log_mel_spectrogram(
+     audio: Union[str, np.ndarray, torch.Tensor],
+     n_mels: int = 128,
+     padding: int = 0,
+     device: Optional[Union[str, torch.device]] = None,
+ ):
+     """
+     Compute the log-Mel spectrogram of an audio waveform.
+
+     Parameters
+     ----------
+     audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+         The path to an audio file, or a NumPy array or Tensor containing the
+         audio waveform sampled at 16 kHz
+
+     n_mels: int
+         The number of Mel-frequency filters, only 80 and 128 are supported
+
+     padding: int
+         Number of zero samples to pad to the right
+
+     device: Optional[Union[str, torch.device]]
+         If given, the audio tensor is moved to this device before STFT
+
+     Returns
+     -------
+     torch.Tensor, shape = (n_mels, n_frames)
+         A Tensor that contains the Mel spectrogram
+     """
+     if not torch.is_tensor(audio):
+         if isinstance(audio, str):
+             audio = load_audio(audio)
+
+     if device is not None:
+         audio = audio.to(device)
+     if padding > 0:
+         audio = F.pad(audio, (0, padding))
+     window = torch.hann_window(400).to(audio.device)
+     stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
+     magnitudes = stft[..., :-1].abs()**2
+
+     filters = _mel_filters(audio.device, n_mels)
+     mel_spec = filters @ magnitudes
+
+     log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+     log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+     log_spec = (log_spec + 4.0) / 4.0
+     return log_spec
+
+
+ def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+     """Make mask tensor containing indices of non-padded part.
+
+     The sequences in a batch may have different lengths. To enable
+     batch computing, padding is needed to make all sequences the same
+     size. This padding part is masked so that it does not pass values
+     to context-dependent blocks such as attention or convolution.
+
+     1 for non-padded part and 0 for padded part.
+
+     Parameters
+     ----------
+     lengths (torch.Tensor): Batch of lengths (B,).
+
+     Returns:
+     -------
+     torch.Tensor: Mask tensor containing indices of non-padded part (B, max_T).
+
+     Examples:
+         >>> import torch
+         >>> import s3tokenizer
+         >>> lengths = torch.tensor([5, 3, 2])
+         >>> masks = s3tokenizer.make_non_pad_mask(lengths)
+         masks = [[1, 1, 1, 1, 1],
+                  [1, 1, 1, 0, 0],
+                  [1, 1, 0, 0, 0]]
+     """
+     batch_size = lengths.size(0)
+     max_len = max_len if max_len > 0 else lengths.max().item()
+     seq_range = torch.arange(0,
+                              max_len,
+                              dtype=torch.int64,
+                              device=lengths.device)
+     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+     seq_length_expand = lengths.unsqueeze(-1)
+     mask = seq_range_expand >= seq_length_expand
+     return ~mask
+
+
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+     """Convert bool-tensor to float-tensor for flash attention.
+
+     Parameters
+     ----------
+     mask (torch.Tensor): Boolean mask (B, ?), True for the non-padded part.
+     dtype (torch.dtype): Target float dtype (float32, float16 or bfloat16).
+
+     Returns:
+     -------
+     torch.Tensor: Attention bias tensor (B, ?), 0 for the non-padded part
+         and -1e10 for the padded part.
+
+     Examples:
+         >>> import torch
+         >>> import s3tokenizer
+         >>> lengths = torch.tensor([5, 3, 2])
+         >>> masks = s3tokenizer.make_non_pad_mask(lengths)
+         masks = [[1, 1, 1, 1, 1],
+                  [1, 1, 1, 0, 0],
+                  [1, 1, 0, 0, 0]]
+         >>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
+         new_masks =
+             [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
+              [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
+              [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
+     """
+     assert mask.dtype == torch.bool
+     assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+     mask = mask.to(dtype)
+
+     # attention mask bias
+     # NOTE(Mddct): torch.finfo jit issues
+     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+     mask = (1.0 - mask) * -1.0e+10
+     return mask
+
+
+ def padding(data: List[torch.Tensor]):
+     """ Padding the data into batch data
+
+     Parameters
+     ----------
+     data: List[Tensor], shape of Tensor (128, T)
+
+     Returns:
+     -------
+     feats [B, 128, T_max], feats lengths [B]
+     """
+     sample = data
+     assert isinstance(sample, list)
+     feats_lengths = torch.tensor([s.size(1) for s in sample],
+                                  dtype=torch.int32)
+     feats = [s.t() for s in sample]
+     padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
+
+     return padded_feats.transpose(1, 2), feats_lengths
+
+
+ def merge_tokenized_segments(tokenized_segments, overlap, token_rate):
+     """
+     Merges tokenized outputs by keeping the middle and dropping half of the overlapped tokens.
+
+     Args:
+     - tokenized_segments (List[List[int]]): List of tokenized sequences.
+     - overlap (int): Overlapping duration in seconds (default: 4s).
+     - token_rate (int): Number of tokens per second.
+
+     Returns:
+     - List[int]: A single merged token sequence.
+     """
+     merged_tokens = []
+     overlap_tokens = (
+         overlap //
+         2) * token_rate  # Tokens corresponding to half of the overlap duration
+
+     for i, tokens in enumerate(tokenized_segments):
+         l = 0 if i == 0 else overlap_tokens
+         r = -overlap_tokens if i != len(tokenized_segments) - 1 else len(tokens)
+         # Keep only the middle part (drop overlap / 2 from both sides)
+         merged_tokens.extend(tokens[l:r])
+
+     return merged_tokens
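
For orientation, the sketch below shows how the utilities in this file compose for batched feature extraction. It is not part of the package; the wav paths are placeholders, and the import path simply follows the module layout shown above.

    import torch
    from s3tokenizer.utils import (load_audio, log_mel_spectrogram,
                                   make_non_pad_mask, mask_to_bias, padding)

    wavs = [load_audio(p) for p in ["a.wav", "b.wav"]]         # mono 16 kHz float32 waveforms
    mels = [log_mel_spectrogram(w, n_mels=128) for w in wavs]  # each of shape (128, T)
    feats, feats_lens = padding(mels)                          # (B, 128, T_max), lengths (B,)
    mask = make_non_pad_mask(feats_lens)                       # bool (B, T_max), True = valid frame
    bias = mask_to_bias(mask, torch.float32)                   # 0 for valid frames, -1e10 for padding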
stepaudio2/__init__.py ADDED
@@ -0,0 +1,40 @@
+ """StepAudio2: Audio tokenizer and TTS model package."""
+
+ from .token2wav import Token2wav
+ from .stepaudio2 import StepAudio2Base, StepAudio2
+
+ # Export classes from flashcosyvoice for backward compatibility
+ from .flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec as CausalMaskedDiffWithXvecFlash
+
+ # Export classes from cosyvoice2 (used in flow.yaml configs)
+ from .cosyvoice2.flow.flow import CausalMaskedDiffWithXvec
+ from .cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2
+ from .cosyvoice2.flow.flow_matching import CausalConditionalCFM
+ from .cosyvoice2.flow.decoder_dit import DiT
+
+ # Export utility classes that might be referenced in configs
+ from .cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
+ from .cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding
+ from .cosyvoice2.transformer.subsampling import LinearNoSubsampling
+ from .cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
+ from .cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
+
+ # Export HiFTGenerator if needed
+ from .flashcosyvoice.modules.hifigan import HiFTGenerator
+
+ __all__ = [
+     'Token2wav',
+     'StepAudio2Base',
+     'StepAudio2',
+     'CausalMaskedDiffWithXvec',
+     'CausalMaskedDiffWithXvecFlash',
+     'UpsampleConformerEncoderV2',
+     'CausalConditionalCFM',
+     'DiT',
+     'RelPositionMultiHeadedAttention',
+     'EspnetRelPositionalEncoding',
+     'LinearNoSubsampling',
+     'ConformerEncoderLayer',
+     'PositionwiseFeedForward',
+     'HiFTGenerator',
+ ]
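
As a side note (not from the package documentation), these root-level re-exports mean the same class objects are importable both from `stepaudio2` and from their defining modules, which is what keeps older configs and code that reference root-level names working. A minimal sketch:

    from stepaudio2 import CausalMaskedDiffWithXvec                       # root-level alias
    from stepaudio2.cosyvoice2.flow.flow import CausalMaskedDiffWithXvec as Direct

    assert CausalMaskedDiffWithXvec is Direct  # same class, re-exported at the package root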
stepaudio2/cosyvoice2/__init__.py ADDED
@@ -0,0 +1 @@
+ """CosyVoice2 subpackage for StepAudio2."""