minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/cosyvoice2/utils/common.py
@@ -0,0 +1,101 @@
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Modified from ESPnet(https://github.com/espnet/espnet)
+ """Utility functions for Transformer."""
+
+ import random
+ from typing import List
+
+ import numpy as np
+ import torch
+
+ IGNORE_ID = -1
+
+
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
+     """Perform padding for the list of tensors.
+
+     Args:
+         xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+         pad_value (float): Value for padding.
+
+     Returns:
+         Tensor: Padded tensor (B, Tmax, `*`).
+
+     Examples:
+         >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+         >>> x
+         [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+         >>> pad_list(x, 0)
+         tensor([[1., 1., 1., 1.],
+                 [1., 1., 0., 0.],
+                 [1., 0., 0., 0.]])
+
+     """
+     max_len = max([len(item) for item in xs])
+     batchs = len(xs)
+     ndim = xs[0].ndim
+     if ndim == 1:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     elif ndim == 2:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               xs[0].shape[1],
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     elif ndim == 3:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               xs[0].shape[1],
+                               xs[0].shape[2],
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     else:
+         raise ValueError(f"Unsupported ndim: {ndim}")
+     pad_res.fill_(pad_value)
+     for i in range(batchs):
+         pad_res[i, :len(xs[i])] = xs[i]
+     return pad_res
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
+     device = fade_in_mel.device
+     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+     mel_overlap_len = int(window.shape[0] / 2)
+     if fade_in_mel.device == torch.device('cpu'):
+         fade_in_mel = fade_in_mel.clone()
+     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+     return fade_in_mel.to(device)
+
+
+ def set_all_random_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
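
For orientation, a minimal usage sketch of the two less obvious helpers above (illustrative only, not part of the package; the tensor shapes, overlap length, and Hann window are assumed values):

    import torch
    from stepaudio2.cosyvoice2.utils.common import fade_in_out, pad_list

    # pad_list: ragged list of [T_i] tensors -> dense [B, Tmax] tensor.
    padded = pad_list([torch.ones(4), torch.ones(2), torch.ones(1)], pad_value=0)

    # fade_in_out: cross-fade two mel chunks over half the window length.
    # The ascending first half of the window scales the incoming chunk's head,
    # the descending second half scales the outgoing chunk's tail.
    overlap = 8
    window = torch.hann_window(2 * overlap, periodic=False)
    incoming = torch.randn(1, 80, 50)   # [B, num_mels, T], fading in
    outgoing = torch.randn(1, 80, 50)   # previous chunk, fading out
    smoothed = fade_in_out(incoming, outgoing, window)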
stepaudio2/cosyvoice2/utils/mask.py
@@ -0,0 +1,49 @@
+ # Copyright (c) 2019 Shigeki Karita
+ #               2020 Mobvoi Inc (Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import torch
+ from typing import List
+
+
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+     """Make mask tensor containing indices of padded part.
+
+     See description of make_non_pad_mask.
+
+     Args:
+         lengths (torch.Tensor): Batch of lengths (B,).
+     Returns:
+         torch.Tensor: Mask tensor containing indices of padded part.
+
+     Examples:
+         >>> lengths = [5, 3, 2]
+         >>> make_pad_mask(lengths)
+         masks = [[0, 0, 0, 0, 0],
+                  [0, 0, 0, 1, 1],
+                  [0, 0, 1, 1, 1]]
+     """
+     batch_size = lengths.size(0)
+     max_len = max_len if max_len > 0 else lengths.max().item()
+     seq_range = torch.arange(0,
+                              max_len,
+                              dtype=torch.int64,
+                              device=lengths.device)
+     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+     seq_length_expand = lengths.unsqueeze(-1)
+     mask = seq_range_expand >= seq_length_expand
+     return mask
+
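
A short usage sketch (illustrative values): make_pad_mask returns a boolean [B, Tmax] tensor that is True on padded positions, which is the shape masked_fill expects, e.g. for pre-softmax attention masking:

    import torch
    from stepaudio2.cosyvoice2.utils.mask import make_pad_mask

    lengths = torch.tensor([5, 3, 2])
    mask = make_pad_mask(lengths)         # [3, 5] bool, True where padded
    scores = torch.randn(3, 5)
    scores = scores.masked_fill(mask, float('-inf'))  # padding never wins softmax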
stepaudio2/flashcosyvoice/cli.py
@@ -0,0 +1,424 @@
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Example usage: see README.md
+ """
+
+ import argparse
+ import json
+ import os
+ import random
+ import sys
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from datetime import datetime
+
+ import numpy as np
+ import onnxruntime
+ import s3tokenizer
+ import torch
+ import torch.distributed as dist
+ import torchaudio
+ import torchaudio.compliance.kaldi as kaldi
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
+ from tqdm import tqdm
+
+ from stepaudio2.flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
+ from stepaudio2.flashcosyvoice.cosyvoice2 import CosyVoice2
+ from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram
+
+
+ def set_all_random_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ def save_file_async(
+     wav, prompt_speech_tokens, generated_speech_tokens,
+     info, timing_stats
+ ):
+     """Save audio asynchronously."""
+     try:
+         os.makedirs(os.path.dirname(info['wav']), exist_ok=True)
+         if wav is not None:
+             wav = wav.cpu()
+             torchaudio.save(info['wav'], wav, 24000)
+             duration = wav.shape[-1] / 24000.0
+             rtf = ((timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']) / duration
+             timing_stats['rtf'] = rtf
+         else:
+             duration = 0.0
+         info['timing_stats'] = timing_stats
+         info['prompt_speech_tokens'] = prompt_speech_tokens
+         info['generated_speech_tokens'] = generated_speech_tokens
+         with open(f"{info['wav'].replace('.wav', '.json')}", "w") as f:
+             json.dump(info, f, ensure_ascii=False, indent=4)
+         return duration
+     except Exception as e:
+         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+         tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
+         return 0.0
+
+
+ class AudioDataset(Dataset):
+
+     def __init__(self, text_norm, text_tokenizer, data_list, model_config: Config):
+         self.datas = []
+         self.text_norm = text_norm
+         self.model_config = model_config
+
+         """Example data_list:
+         ```
+         {"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
+         {"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}
+         ```
+         Note:
+             - `key` is the key of this sample.
+             - `prompt_text` is the text used for the prompt.
+             - `text` is the text used for generating real audio.
+             - `prompt_wav` is the audio used for the prompt.
+             - `wav` is the path where the generated audio will be saved (we highly recommend pre-defining the save path before running the script).
+         """
+         missing = 0
+         with open(data_list, 'r', encoding='utf-8') as f:
+             lines = f.readlines()
+         total_lines = len(lines)
+         if torch.distributed.get_node_local_rank() == 0:
+             iterator = tqdm(lines, desc='Loading data')
+         else:
+             iterator = lines
+         for line in iterator:
+             data = json.loads(line.strip())
+             valid = True
+             for k in ['key', 'prompt_text', 'text', 'prompt_wav']:
+                 if k not in data:
+                     valid = False
+                     break
+                 if data[k] is None:
+                     valid = False
+                     break
+             if not os.path.exists(data['prompt_wav']):
+                 valid = False
+             if valid:
+                 self.datas.append(data)
+             else:
+                 missing += 1
+         if torch.distributed.get_node_local_rank() == 0:
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} missing lines, total valid lines == {len(self.datas)}.')
+
+         self.text_tokenizer = text_tokenizer
+
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
+                                                       providers=["CPUExecutionProvider"])
+
+     def __len__(self):
+         return len(self.datas)
+
+     def __getitem__(self, idx):
+         data = self.datas[idx]
+
+         try:
+             # 1. feature for s3tokenizer
+             audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000)  # [T]
+             log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+
+             # 2. feature for speaker embedding
+             spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+             spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+             spk_emb = self.spk_model.run(
+                 None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+             )[0].flatten().tolist()
+
+             # 3. feature for flow
+             audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile')
+             audio = audio.mean(dim=0, keepdim=True)  # [1, T]
+             if sample_rate != 24000:
+                 audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+             mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+             mel_len = mel.shape[0]
+
+             # 4. feature for llm
+             if self.text_norm is not None:
+                 prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
+                 prompt_text = ''.join(prompt_texts)
+                 texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
+                 text = ''.join(texts)
+             else:
+                 prompt_text = data['prompt_text']
+                 text = data['text']
+             prompt_text_ids = self.text_tokenizer.encode(prompt_text)
+             prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
+             text_ids = self.text_tokenizer.encode(text)
+             text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
+             item = {
+                 "prompt_text_tokens": prompt_text_ids, "text_tokens": text_ids,
+                 "spk_emb": spk_emb, "mel": mel, "mel_len": mel_len, "log_mel": log_mel, "info": data,
+                 "min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
+                 "max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
+             }
+         except Exception as e:
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
+             return None
+         return item
+
+
+ def collate_fn(batch):
+     prompt_mels_for_llm = [item["log_mel"] for item in batch if item is not None]
+     prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_mels_for_llm)  # [B, num_mels=128, T]
+     prompt_text_tokens_for_llm = [item["prompt_text_tokens"] for item in batch if item is not None]
+     text_tokens_for_llm = [item["text_tokens"] for item in batch if item is not None]
+     prompt_mels_for_flow = [item["mel"] for item in batch if item is not None]
+     prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+     prompt_mels_lens_for_flow = [item["mel_len"] for item in batch if item is not None]
+     prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+     spk_emb_for_flow = [item["spk_emb"] for item in batch if item is not None]
+     spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+     sampling_params = [SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True) for item in batch if item is not None]
+     infos = [item["info"] for item in batch if item is not None]
+     return {
+         "prompt_mels_for_llm": prompt_mels_for_llm,
+         "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
+         "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
+         "text_tokens_for_llm": text_tokens_for_llm,
+         "prompt_mels_for_flow": prompt_mels_for_flow,
+         "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
+         "spk_emb_for_flow": spk_emb_for_flow,
+         "sampling_params": sampling_params,
+         "infos": infos,
+     }
+
+
+ def init_distributed():
+     world_size = int(os.environ.get('WORLD_SIZE', 1))
+     local_rank = int(os.environ.get('LOCAL_RANK', 0))
+     rank = int(os.environ.get('RANK', 0))
+     timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+     tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
+     torch.cuda.set_device(local_rank)
+     dist.init_process_group("nccl")
+     return world_size, local_rank, rank
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='FlashCosyVoice')
+     parser.add_argument('--model_path',
+                         required=True,
+                         type=str,
+                         help='model path')
+     parser.add_argument('--data_list',
+                         required=True,
+                         type=str,
+                         help='data list')
+     parser.add_argument('--batch_size_dataloader',
+                         required=True,
+                         type=int,
+                         help='batch size (per-device) for dataloading')
+     parser.add_argument('--batch_size_flow',
+                         required=True,
+                         type=int,
+                         help='batch size (per-device) for flow-matching')
+     parser.add_argument('--num_workers',
+                         type=int,
+                         default=4,
+                         help='workers for dataloader')
+     parser.add_argument('--prefetch',
+                         type=int,
+                         default=5,
+                         help='prefetch for dataloader')
+     parser.add_argument('--enable_tn',
+                         action='store_true',
+                         help='enable text normalization')
+     parser.add_argument('--only_llm',
+                         action='store_true',
+                         help='only generate speech tokens from llm')
+     parser.add_argument('--fp16_flow',
+                         action='store_true',
+                         help='enable fp16 flow')
+     parser.add_argument('--seed',
+                         type=int,
+                         default=1986,
+                         help='random seed for generation')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     if args.enable_tn:
+         # Check python version; if == 3.10, use ttsfrd
+         if sys.version_info.major == 3 and sys.version_info.minor == 10:
+             # Check if ttsfrd is installed
+             try:
+                 import ttsfrd
+                 from cosyvoice_ttsfrd import get_resource_path
+             except ImportError as e:
+                 raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for installation guide.") from e
+             text_norm = ttsfrd.TtsFrontendEngine()
+             text_norm.initialize(get_resource_path())
+             text_norm.set_lang_type('pinyinvg')
+         else:
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f"[{timestamp}] - [WARNING] - Only python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
+             # TODO: maybe we should use wetext if python version is not 3.10?
+             args.enable_tn = False
+             text_norm = None
+     else:
+         text_norm = None
+
+     assert torch.cuda.is_available()
+     world_size, local_rank, rank = init_distributed()
+     config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
+                     max_num_seqs=args.batch_size_dataloader,
+                     hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
+     model = CosyVoice2(config)
+
+     set_all_random_seed(args.seed)
+
+     dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
+     sampler = DistributedSampler(dataset,
+                                  num_replicas=world_size,
+                                  rank=rank)
+     dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers, pin_memory=True,
+                             sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
+     total_steps = len(dataset)
+
+     if local_rank == 0:
+         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+         tqdm.write(f"[{timestamp}] - [INFO] - {args}")
+         progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
+                             position=0, leave=True, dynamic_ncols=True)
+
+     cpu_counts = os.cpu_count()
+     executor = ThreadPoolExecutor(max_workers=min(args.batch_size_dataloader, cpu_counts // 8))
+     pending_futures = []
+     dataloader_iter = iter(dataloader)
+     succeed_duration = 0.01  # avoid division by zero
+     start_time = time.time()
+     estimated_total_wavs = 0
+     succeed_wavs = 0
+     failed_wavs = 0
+     last_print_time = start_time
+
+     while True:
+         try:
+             dataloader_start = time.time()
+             batch = next(dataloader_iter)
+             dataloader_time = time.time() - dataloader_start
+
+             if len(batch['infos']) == 0:
+                 timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+                 tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
+                 continue
+
+             model_start = time.time()
+             results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow,
+                                                only_llm=args.only_llm)
+             model_time = time.time() - model_start
+
+             estimated_total_wavs += len(results_dict['generated_wavs'])
+
+             timing_stats['dataloader_time'] = dataloader_time
+             timing_stats['model_inference_time'] = model_time
+
+             if args.only_llm:
+                 results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])
+
+             for i in range(len(results_dict['generated_wavs'])):
+                 future = executor.submit(
+                     save_file_async, results_dict['generated_wavs'][i],
+                     results_dict['prompt_speech_tokens'][i],
+                     results_dict['generated_speech_tokens'][i],
+                     batch['infos'][i].copy(), timing_stats.copy()
+                 )
+                 pending_futures.append(future)
+
+             completed_futures = []
+             for future in pending_futures:
+                 if future.done():
+                     try:
+                         duration = future.result()
+                         succeed_duration += duration
+                         succeed_wavs += 1
+                     except Exception as e:
+                         failed_wavs += 1
+                         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+                         tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
+                     completed_futures.append(future)
+
+             for future in completed_futures:
+                 pending_futures.remove(future)
+
+             if local_rank == 0:
+                 update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
+                 if progress_bar.n + update_n > progress_bar.total:
+                     progress_bar.update(progress_bar.total - progress_bar.n)
+                 else:
+                     progress_bar.update(update_n)
+
+             current_time = time.time()
+             if current_time - last_print_time >= 120 and not args.only_llm:
+                 elapsed_time = current_time - start_time
+                 avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
+                 estimated_total_duration = avg_duration * estimated_total_wavs
+                 current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
+                 timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+                 tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s")  # noqa
+                 last_print_time = current_time
+         except StopIteration:
+             break
+         except Exception as e:
+             failed_wavs += 1
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
+             continue
+
+     total_time = time.time() - start_time
+
+     if local_rank == 0:
+         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+         tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
+
+     for future in pending_futures:
+         try:
+             duration = future.result(timeout=60)
+             succeed_duration += duration
+             succeed_wavs += 1
+         except Exception as e:
+             failed_wavs += 1
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
+     executor.shutdown(wait=True)
+
+     if local_rank == 0:
+         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+         tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
+         progress_bar.close()
+
+     if not args.only_llm:
+         timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+         tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}")  # noqa
+
+     dist.barrier()
+     dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     main()
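
Two practical notes, inferred from the code above rather than from packaged docs: the input is the JSONL data_list validated in AudioDataset.__init__, and, because main() reads WORLD_SIZE/RANK/LOCAL_RANK and initializes an NCCL process group, the script is meant to be launched through a distributed launcher such as torchrun (e.g. `torchrun --nproc_per_node=8 -m stepaudio2.flashcosyvoice.cli ...`). A sketch that emits one valid data_list entry; the key, paths, and output filename are placeholders:

    import json

    # Required keys checked by AudioDataset: 'key', 'prompt_text', 'text', 'prompt_wav';
    # 'wav' is where the synthesized audio (and a .json report) will be written.
    entry = {
        "key": "uttid_1",
        "prompt_text": "你好,我是小明。",
        "text": "你好,我是小红。",
        "prompt_wav": "/mnt/data/audio/00000000.wav",
        "wav": "/mnt/data/audio_synthetic/uttid_1.wav",
    }
    with open("data.jsonl", "w", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")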
stepaudio2/flashcosyvoice/config.py
@@ -0,0 +1,80 @@
+ import os
+ from dataclasses import dataclass, field
+
+ import torch
+ from transformers import AutoConfig
+
+
+ @dataclass
+ class CosyVoice2LLMConfig:
+     architectures: list[str] = field(default_factory=lambda: ["Qwen2ForCausalLM"])
+     attention_dropout: float = 0.0
+     bos_token_id: int = 151643
+     eos_token_id: int = 6561  # speech eos
+     hidden_act: str = "silu"
+     hidden_size: int = 896
+     initializer_range: float = 0.02
+     intermediate_size: int = 4864
+     max_position_embeddings: int = 32768
+     max_window_layers: int = 24
+     model_type: str = "qwen2"
+     num_attention_heads: int = 14
+     num_hidden_layers: int = 24
+     num_key_value_heads: int = 2
+     head_dim: int = 64
+     rms_norm_eps: float = 1e-06
+     rope_scaling: dict | None = None
+     rope_theta: float = 1000000.0
+     sliding_window: int = 32768
+     tie_word_embeddings: bool = False
+     torch_dtype: torch.dtype = torch.bfloat16
+     transformers_version: str = "4.52.0.dev0"
+     use_cache: bool = True
+     use_sliding_window: bool = False
+     vocab_size: int = 158500  # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
+     text_vocab_size: int = 151936
+     speech_vocab_size: int = 6562  # actually 6564, we only care about non-streaming inference, so cut off tokens (6562, 6563) that are only used for streaming TTS
+     lm_head_bias: bool = True
+     qkv_bias: bool = True
+     fp16_flow: bool = True
+
+
+ @dataclass
+ class SamplingParams:
+     temperature: float = 1.0
+     min_tokens: int = 2
+     max_tokens: int = 64
+     ignore_eos: bool = False
+     top_k: int = 25
+     # RasSampler parameters
+     use_ras: bool = False
+     win_size: int = 10
+     tau_r: float = 0.1
+     top_p: float = 0.8
+
+
+ @dataclass
+ class Config:
+     model: str
+     max_num_batched_tokens: int = 1572864
+     max_num_seqs: int = 1024
+     max_model_len: int = 1536  # 15s prompt + 30s generated audio for 25hz audio tokenizer
+     gpu_memory_utilization: float = 0.9
+     tensor_parallel_size: int = 1
+     enforce_eager: bool = False
+     hf_config: CosyVoice2LLMConfig | AutoConfig = field(default_factory=CosyVoice2LLMConfig)
+     eos: int = -1
+     kvcache_block_size: int = 256
+     num_kvcache_blocks: int = -1
+     min_token_text_ratio: int = 2
+     max_token_text_ratio: int = 20
+     rank: int = 0
+
+     def __post_init__(self):
+         assert os.path.isdir(self.model)
+         assert self.kvcache_block_size % 256 == 0
+         assert 1 <= self.tensor_parallel_size <= 8
+
+         max_pos = getattr(self.hf_config, "max_position_embeddings", 4096)
+         self.max_model_len = min(self.max_model_len, max_pos)
+         assert self.max_num_batched_tokens >= self.max_model_len
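
The vocabulary arithmetic here is consistent: speech_vocab_size (6562) + 2 special tokens (eos and task_id) + text_vocab_size (151936) = 158500 = vocab_size, which is also why cli.py offsets text token ids by speech_vocab_size + 2. A minimal construction sketch; the checkpoint directory is a placeholder and must exist, since Config.__post_init__ asserts os.path.isdir:

    from stepaudio2.flashcosyvoice.config import (
        Config, CosyVoice2LLMConfig, SamplingParams,
    )

    # Hypothetical local checkpoint directory; __post_init__ raises if missing.
    cfg = Config(model="/path/to/CosyVoice2-0.5B",
                 enforce_eager=True,
                 hf_config=CosyVoice2LLMConfig(fp16_flow=True))

    # Per-request decoding bounds; cli.py derives min/max_tokens from the text
    # length times min/max_token_text_ratio and enables RAS sampling.
    params = SamplingParams(min_tokens=20, max_tokens=200, use_ras=True)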