minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
s3tokenizer/utils.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
|
|
2
|
+
# 2024 Tsinghua Univ. (authors: Xingchen Song)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
|
|
16
|
+
Add rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias()
|
|
17
|
+
Copy merge_tokenized_segments() from https://github.com/Mddct/s3tokenizer-long/blob/main/example.py
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from functools import lru_cache
|
|
22
|
+
from typing import List, Optional, Union
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import onnx
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
import torchaudio
|
|
29
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _rename_weights(weights_dict: dict):
|
|
33
|
+
"""
|
|
34
|
+
Rename onnx weights to pytorch format.
|
|
35
|
+
|
|
36
|
+
Parameters
|
|
37
|
+
----------
|
|
38
|
+
weight_dict: dict
|
|
39
|
+
The dict containing weights in onnx format
|
|
40
|
+
|
|
41
|
+
Returns
|
|
42
|
+
-------
|
|
43
|
+
A new weight dict containing the weights in pytorch format.
|
|
44
|
+
"""
|
|
45
|
+
new_weight_dict = {}
|
|
46
|
+
for k in weights_dict.keys():
|
|
47
|
+
if "quantizer" in k: # vq or fsq
|
|
48
|
+
if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1":
|
|
49
|
+
new_weight_dict["quantizer._codebook.embed"] = weights_dict[k]
|
|
50
|
+
elif 'project_down' in k: # v2
|
|
51
|
+
new_weight_dict[k] = weights_dict[k]
|
|
52
|
+
elif "positional_embedding" in k: # positional emb
|
|
53
|
+
new_weight_dict[k] = weights_dict[k]
|
|
54
|
+
elif "conv" in k: # 1/2 or 1/4 subsample
|
|
55
|
+
new_weight_dict[k] = weights_dict[k]
|
|
56
|
+
else: # transformer blocks
|
|
57
|
+
assert "blocks" in k
|
|
58
|
+
new_k = (k[1:].replace('/', '.').replace(
|
|
59
|
+
'MatMul', 'weight').replace('Add_1', 'bias').replace(
|
|
60
|
+
'Mul', 'weight').replace('Add', 'bias').replace(
|
|
61
|
+
'mlp.mlp', 'mlp')).replace('fsmn_block.Conv',
|
|
62
|
+
'fsmn_block.weight')
|
|
63
|
+
|
|
64
|
+
new_weight_dict[f"encoder.{new_k}"] = weights_dict[k]
|
|
65
|
+
return new_weight_dict
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _onnx_initializer_to_tensor(initializer) -> torch.Tensor:
    """Convert an ONNX initializer into a torch tensor backed by its own copy."""
    # The array is copied so the tensor owns writable memory.
    return torch.from_numpy(onnx.numpy_helper.to_array(initializer).copy())


def onnx2torch(onnx_path: str,
               torch_path: Optional[str] = None,
               verbose: bool = False):
    """
    Open an onnx file and convert to pytorch format.

    Parameters
    ----------
    onnx_path: str
        The onnx file to open, typically `speech_tokenizer_v1.onnx`

    torch_path: Optional[str]
        The path to save the torch-formatted checkpoint. If falsy, nothing is
        saved and the checkpoint dict is returned instead.

    verbose: bool
        Print the name, shape and dtype of every converted weight, and the
        save path once a checkpoint has been written.

    Returns
    -------
    A checkpoint dict containing the weights and their names, if torch_path is
    None. Otherwise the checkpoint dict is saved to the desired path and None
    is returned.
    """
    onnx_model = onnx.load(onnx_path)
    weights_dict = {}
    initializer_map = {
        initializer.name: initializer
        for initializer in onnx_model.graph.initializer
    }
    for node in onnx_model.graph.node:
        for input_name in node.input:
            if input_name not in initializer_map:
                continue
            ln_bias_name, ln_weight_name = None, None  # for v2 ln
            initializer = initializer_map[input_name]
            if input_name in [
                    "onnx::Conv_1519",
                    "encoders.conv1.weight",
                    "onnx::Conv_2216",
            ]:  # v1_50hz, v1_25hz, v2_25hz
                weight_name = "encoder.conv1.weight"
            elif input_name in [
                    "onnx::Conv_1520",
                    "encoders.conv1.bias",
                    "onnx::Conv_2217",
            ]:  # v1_50hz, v1_25hz, v2_25hz
                weight_name = "encoder.conv1.bias"
            elif input_name in [
                    "onnx::Conv_1521",
                    "encoders.conv2.weight",
                    "onnx::Conv_2218",
            ]:
                weight_name = "encoder.conv2.weight"
            elif input_name in [
                    "onnx::Conv_1522",
                    "encoders.conv2.bias",
                    "onnx::Conv_2219",
            ]:
                weight_name = "encoder.conv2.bias"
            elif input_name == "encoders.positional_embedding":
                weight_name = "encoder.positional_embedding"
            elif input_name == 'quantizer.project_in.bias':
                weight_name = "quantizer._codebook.project_down.bias"
            elif input_name == 'onnx::MatMul_2536':
                weight_name = "quantizer._codebook.project_down.weight"
            else:
                if node.op_type == 'LayerNormalization':
                    ln_name = node.name.replace('/LayerNormalization', '')
                    ln_weight_name = ln_name + '.weight'
                    ln_bias_name = ln_name + '.bias'
                else:
                    weight_name = node.name

            if ln_weight_name is not None and ln_bias_name is not None:
                # LayerNormalization inputs are (X, scale, bias). Export only
                # the parameters that are actually stored as initializers
                # instead of crashing on a missing one.
                ln_inputs = node.input
                scale_name = ln_inputs[1] if len(ln_inputs) > 1 else None
                bias_name = ln_inputs[2] if len(ln_inputs) > 2 else None
                if scale_name in initializer_map:
                    weights_dict[ln_weight_name] = _onnx_initializer_to_tensor(
                        initializer_map[scale_name])
                if bias_name in initializer_map:
                    weights_dict[ln_bias_name] = _onnx_initializer_to_tensor(
                        initializer_map[bias_name])
            else:
                weight_tensor = _onnx_initializer_to_tensor(initializer)
                # Tensors with <= 2 dims (except the positional embedding)
                # are transposed, presumably to match torch.nn.Linear's
                # (out, in) layout; .t() is a no-op on 1-D tensors.
                if len(weight_tensor.shape) > 2 or weight_name in [
                        "encoder.positional_embedding"
                ]:
                    weights_dict[weight_name] = weight_tensor
                else:
                    weights_dict[weight_name] = weight_tensor.t()

    new_weights_dict = _rename_weights(weights_dict)
    if verbose:
        for k, v in new_weights_dict.items():
            print(f"{k} : {v.shape} {v.dtype}")
    del weights_dict, onnx_model
    if torch_path:
        torch.save(new_weights_dict, torch_path)
        if verbose:
            # Report the save only after it has actually happened.
            print(f"PyTorch weights saved to {torch_path}")
        return None
    return new_weights_dict
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def load_audio(file: str, sr: int = 16000):
    """
    Read an audio file and return its first channel, resampled if needed.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The target sample rate; the audio is resampled only when the file's
        native rate differs from it

    Returns
    -------
    A 1-D torch.Tensor holding the first channel of the waveform, as returned
    by torchaudio.load (no channel downmix is performed).
    """
    waveform, native_rate = torchaudio.load(file)
    if native_rate != sr:
        resampler = torchaudio.transforms.Resample(native_rate, sr)
        waveform = resampler(waveform)
    # keep only the first channel
    return waveform[0]
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@lru_cache(maxsize=None)
|
|
201
|
+
def _mel_filters(device, n_mels: int) -> torch.Tensor:
|
|
202
|
+
"""
|
|
203
|
+
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
|
204
|
+
Allows decoupling librosa dependency; saved using:
|
|
205
|
+
|
|
206
|
+
np.savez_compressed(
|
|
207
|
+
"mel_filters.npz",
|
|
208
|
+
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
|
209
|
+
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
|
210
|
+
)
|
|
211
|
+
"""
|
|
212
|
+
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
|
213
|
+
|
|
214
|
+
filters_path = os.path.join(os.path.dirname(__file__), "assets",
|
|
215
|
+
"mel_filters.npz")
|
|
216
|
+
with np.load(filters_path, allow_pickle=False) as f:
|
|
217
|
+
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 128,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the normalized log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to an audio file, or a NumPy array or Tensor containing the
        audio waveform in 16 kHz. Paths are read with ``load_audio``; NumPy
        arrays are converted to float32 tensors.

    n_mels: int
        The number of Mel-frequency filters; 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames) for 1-D input,
    or (B, n_mels, n_frames) for batched (B, T) input
        A Tensor that contains the log-Mel spectrogram. The dynamic-range
        floor uses ``log_spec.max()`` over the whole tensor, so batched
        entries share one floor.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        else:
            # NumPy input used to reach .to()/F.pad/torch.stft unconverted and
            # crash. The Hann window built below is float32, so cast to match.
            audio = torch.as_tensor(np.asarray(audio), dtype=torch.float32)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(400).to(audio.device)
    stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
    # drop the last frame and take the power spectrum
    magnitudes = stft[..., :-1].abs()**2

    filters = _mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # limit the dynamic range to 8 (log10 units) below the global maximum
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask marking the non-padded positions of a batch.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To keep the padded part from feeding values into context-dependent
    blocks such as attention or convolution, the padded part is masked.

    True for non-padded positions and False for padded positions.

    Parameters
    ----------
    lengths (torch.Tensor): Batch of lengths (B,).
    max_len (int): Width of the mask. If <= 0, ``lengths.max()`` is used
        (0 for an empty batch).

    Returns:
    -------
    torch.Tensor: Bool mask (B, max_len), True where position < length.

    Examples:
        >>> import torch
        >>> import s3tokenizer
        >>> lengths = torch.tensor([5, 3, 2])
        >>> masks = s3tokenizer.make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    if max_len <= 0:
        # An empty batch has no maximum; use a zero-width mask instead of
        # letting lengths.max() raise.
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    # (1, max_len) compared against (B, 1) broadcasts to (B, max_len).
    return seq_range.unsqueeze(0) < lengths.unsqueeze(-1)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a boolean keep-mask into an additive attention bias.

    Parameters
    ----------
    mask (torch.Tensor): Bool tensor of any shape; True marks positions to keep.
    dtype (torch.dtype): One of float32, bfloat16 or float16.

    Returns:
    -------
    torch.Tensor: Tensor of ``dtype`` with the same shape as ``mask``,
        holding 0 where the mask is True and -1e10 where it is False.
        NOTE(review): -1e10 is not representable in float16 and becomes -inf.

    Examples:
        >>> import torch
        >>> import s3tokenizer
        >>> masks = s3tokenizer.make_non_pad_mask(torch.tensor([5, 3, 2]))
        >>> s3tokenizer.mask_to_bias(masks, torch.float32)
        [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
         [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    keep = mask.to(dtype)
    # NOTE(Mddct): torch.finfo(dtype).min has jit issues, so a fixed large
    # negative constant is used for the masked positions instead.
    return (1.0 - keep) * -1.0e+10
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def padding(data: List[torch.Tensor]):
    """Zero-pad a list of (C, T) feature tensors into one batch.

    Parameters
    ----------
    data: List[Tensor], each of shape (C, T_i), e.g. (128, T_i) mel features

    Returns:
    -------
    A tuple of the padded batch [B, C, T_max] and the original lengths
    T_i as an int32 tensor [B].
    """
    assert isinstance(data, list)
    lengths = torch.tensor([feat.size(1) for feat in data], dtype=torch.int32)
    # pad_sequence pads along the first dim, so go time-major and back.
    time_major = [feat.t() for feat in data]
    batch = pad_sequence(time_major, batch_first=True, padding_value=0)
    return batch.transpose(1, 2), lengths
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def merge_tokenized_segments(tokenized_segments, overlap, token_rate):
    """
    Merge tokenized segments whose audio overlaps into one token sequence.

    Consecutive segments are expected to share ``overlap`` seconds of audio,
    i.e. ``overlap * token_rate`` tokens. Half of those shared tokens are
    dropped from the start of the later segment and the remainder from the
    end of the earlier one, so each overlapped region is kept once. The first
    segment keeps its head and the last segment keeps its tail.

    Args:
    - tokenized_segments (List[List[int]]): List of tokenized sequences.
    - overlap (int or float): Overlapping duration in seconds (e.g. 4s).
    - token_rate (int): Number of tokens per second.

    Returns:
    - List[int]: A single merged token sequence.
    """
    merged_tokens = []
    # Convert to tokens before halving so odd durations are not truncated
    # (``overlap // 2`` lost half a second and duplicated tokens), and cast to
    # int so float durations are valid slice bounds.
    overlap_total = int(overlap * token_rate)
    head_drop = overlap_total // 2  # dropped from the start of non-first segments
    tail_drop = overlap_total - head_drop  # dropped from the end of non-last segments

    last_index = len(tokenized_segments) - 1
    for i, tokens in enumerate(tokenized_segments):
        start = 0 if i == 0 else head_drop
        # Use an explicit end index: ``tokens[:-0]`` is empty, which silently
        # discarded every non-last segment when the overlap rounded to 0.
        if i == last_index:
            end = len(tokens)
        else:
            end = max(len(tokens) - tail_drop, 0)
        merged_tokens.extend(tokens[start:end])

    return merged_tokens
|
stepaudio2/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""StepAudio2: Audio tokenizer and TTS model package."""
|
|
2
|
+
|
|
3
|
+
from .token2wav import Token2wav
|
|
4
|
+
from .stepaudio2 import StepAudio2Base, StepAudio2
|
|
5
|
+
|
|
6
|
+
# Export classes from flashcosyvoice for backward compatibility
|
|
7
|
+
from .flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec as CausalMaskedDiffWithXvecFlash
|
|
8
|
+
|
|
9
|
+
# Export classes from cosyvoice2 (used in flow.yaml configs)
|
|
10
|
+
from .cosyvoice2.flow.flow import CausalMaskedDiffWithXvec
|
|
11
|
+
from .cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2
|
|
12
|
+
from .cosyvoice2.flow.flow_matching import CausalConditionalCFM
|
|
13
|
+
from .cosyvoice2.flow.decoder_dit import DiT
|
|
14
|
+
|
|
15
|
+
# Export utility classes that might be referenced in configs
|
|
16
|
+
from .cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
|
|
17
|
+
from .cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding
|
|
18
|
+
from .cosyvoice2.transformer.subsampling import LinearNoSubsampling
|
|
19
|
+
from .cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
|
|
20
|
+
from .cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
|
21
|
+
|
|
22
|
+
# Export HiFTGenerator if needed
|
|
23
|
+
from .flashcosyvoice.modules.hifigan import HiFTGenerator
|
|
24
|
+
|
|
25
|
+
# Public API re-exported by the stepaudio2 package (order preserved).
__all__ = [
    # top-level entry points
    'Token2wav', 'StepAudio2Base', 'StepAudio2',
    # flow modules (cosyvoice2 and flashcosyvoice variants)
    'CausalMaskedDiffWithXvec', 'CausalMaskedDiffWithXvecFlash',
    'UpsampleConformerEncoderV2', 'CausalConditionalCFM', 'DiT',
    # transformer building blocks
    'RelPositionMultiHeadedAttention', 'EspnetRelPositionalEncoding',
    'LinearNoSubsampling', 'ConformerEncoderLayer', 'PositionwiseFeedForward',
    # HiFi-GAN generator from flashcosyvoice.modules.hifigan
    'HiFTGenerator',
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""CosyVoice2 subpackage for StepAudio2."""
|
|
File without changes
|