minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/utils/class_utils.py
@@ -0,0 +1,83 @@
+ # Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
+ #            2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+
+ from cosyvoice.transformer.activation import Swish
+ from cosyvoice.transformer.subsampling import (
+     LinearNoSubsampling,
+     EmbedinigNoSubsampling,
+     Conv1dSubsampling2,
+     Conv2dSubsampling4,
+     Conv2dSubsampling6,
+     Conv2dSubsampling8,
+ )
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
+                                              RelPositionalEncoding,
+                                              WhisperPositionalEncoding,
+                                              LearnablePositionalEncoding,
+                                              NoPositionalEncoding)
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
+                                              RelPositionMultiHeadedAttention)
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+ from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+ from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+ from cosyvoice.hifigan.generator import HiFTGenerator
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+
+
+ COSYVOICE_ACTIVATION_CLASSES = {
+     "hardtanh": torch.nn.Hardtanh,
+     "tanh": torch.nn.Tanh,
+     "relu": torch.nn.ReLU,
+     "selu": torch.nn.SELU,
+     "swish": getattr(torch.nn, "SiLU", Swish),
+     "gelu": torch.nn.GELU,
+ }
+
+ COSYVOICE_SUBSAMPLE_CLASSES = {
+     "linear": LinearNoSubsampling,
+     "linear_legacy": LegacyLinearNoSubsampling,
+     "embed": EmbedinigNoSubsampling,
+     "conv1d2": Conv1dSubsampling2,
+     "conv2d": Conv2dSubsampling4,
+     "conv2d6": Conv2dSubsampling6,
+     "conv2d8": Conv2dSubsampling8,
+     'paraformer_dummy': torch.nn.Identity
+ }
+
+ COSYVOICE_EMB_CLASSES = {
+     "embed": PositionalEncoding,
+     "abs_pos": PositionalEncoding,
+     "rel_pos": RelPositionalEncoding,
+     "rel_pos_espnet": EspnetRelPositionalEncoding,
+     "no_pos": NoPositionalEncoding,
+     "abs_pos_whisper": WhisperPositionalEncoding,
+     "embed_learnable_pe": LearnablePositionalEncoding,
+ }
+
+ COSYVOICE_ATTENTION_CLASSES = {
+     "selfattn": MultiHeadedAttention,
+     "rel_selfattn": RelPositionMultiHeadedAttention,
+ }
+
+
+ def get_model_type(configs):
+     # NOTE CosyVoice2Model inherits CosyVoiceModel
+     if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+         return CosyVoiceModel
+     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+         return CosyVoice2Model
+     raise TypeError('No valid model type found!')
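These registries let a YAML config select encoder components by string key. A minimal lookup sketch (the keys chosen here are illustrative; constructor arguments for the subsampling and embedding classes follow wenet-style signatures and are omitted):

from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
)

# "swish" resolves to torch.nn.SiLU on torch >= 1.7, else the local Swish
activation = COSYVOICE_ACTIVATION_CLASSES["swish"]()
subsampling_cls = COSYVOICE_SUBSAMPLE_CLASSES["conv2d"]   # Conv2dSubsampling4
pos_emb_cls = COSYVOICE_EMB_CLASSES["rel_pos"]            # RelPositionalEncoding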
cosyvoice/utils/common.py
@@ -0,0 +1,186 @@
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Modified from ESPnet(https://github.com/espnet/espnet)
+ """Utility functions for Transformer."""
+
+ import queue
+ import random
+ from typing import List
+
+ import numpy as np
+ import torch
+
+ IGNORE_ID = -1
+
+
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
+     """Perform padding for the list of tensors.
+
+     Args:
+         xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+         pad_value (float): Value for padding.
+
+     Returns:
+         Tensor: Padded tensor (B, Tmax, `*`).
+
+     Examples:
+         >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+         >>> x
+         [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+         >>> pad_list(x, 0)
+         tensor([[1., 1., 1., 1.],
+                 [1., 1., 0., 0.],
+                 [1., 0., 0., 0.]])
+
+     """
+     max_len = max([len(item) for item in xs])
+     batchs = len(xs)
+     ndim = xs[0].ndim
+     if ndim == 1:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     elif ndim == 2:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               xs[0].shape[1],
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     elif ndim == 3:
+         pad_res = torch.zeros(batchs,
+                               max_len,
+                               xs[0].shape[1],
+                               xs[0].shape[2],
+                               dtype=xs[0].dtype,
+                               device=xs[0].device)
+     else:
+         raise ValueError(f"Unsupported ndim: {ndim}")
+     pad_res.fill_(pad_value)
+     for i in range(batchs):
+         pad_res[i, :len(xs[i])] = xs[i]
+     return pad_res
+
+
+ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
+                 ignore_label: int) -> torch.Tensor:
+     """Calculate accuracy.
+
+     Args:
+         pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+         pad_targets (LongTensor): Target label tensors (B, Lmax).
+         ignore_label (int): Ignore label id.
+
+     Returns:
+         torch.Tensor: Accuracy value (0.0 - 1.0).
+
+     """
+     pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
+                                 pad_outputs.size(1)).argmax(2)
+     mask = pad_targets != ignore_label
+     numerator = torch.sum(
+         pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
+     denominator = torch.sum(mask)
+     return (numerator / denominator).detach()
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ # Repetition Aware Sampling in VALL-E 2
+ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+     top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+     rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+     if rep_num >= win_size * tau_r:
+         top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
+     return top_ids
+
+
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+     prob, indices = [], []
+     cum_prob = 0.0
+     sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+     for i in range(len(sorted_idx)):
+         # sampling both top-p and numbers.
+         if cum_prob < top_p and len(prob) < top_k:
+             cum_prob += sorted_value[i]
+             prob.append(sorted_value[i])
+             indices.append(sorted_idx[i])
+         else:
+             break
+     prob = torch.tensor(prob).to(weighted_scores)
+     indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+     top_ids = indices[prob.multinomial(1, replacement=True)]
+     return top_ids
+
+
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
+     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+     return top_ids
+
+
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
+     device = fade_in_mel.device
+     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+     mel_overlap_len = int(window.shape[0] / 2)
+     if fade_in_mel.device == torch.device('cpu'):
+         fade_in_mel = fade_in_mel.clone()
+     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+     return fade_in_mel.to(device)
+
+
+ def set_all_random_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+     assert mask.dtype == torch.bool
+     assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+     mask = mask.to(dtype)
+     # attention mask bias
+     # NOTE(Mddct): torch.finfo jit issues
+     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+     mask = (1.0 - mask) * -1.0e+10
+     return mask
+
+
+ class TrtContextWrapper:
+     def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+         self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+         self.trt_engine = trt_engine
+         for _ in range(trt_concurrent):
+             trt_context = trt_engine.create_execution_context()
+             trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
+             assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
+             self.trt_context_pool.put([trt_context, trt_stream])
+         assert self.trt_context_pool.empty() is False, 'no available estimator context'
+
+     def acquire_estimator(self):
+         return self.trt_context_pool.get(), self.trt_engine
+
+     def release_estimator(self, context, stream):
+         self.trt_context_pool.put([context, stream])
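ras_sampling above implements the repetition-aware sampling strategy from VALL-E 2: nucleus-sample first, then fall back to unconstrained multinomial sampling when the proposed token already repeats too often in the recent window. A usage sketch (the logits and token history are made up; the sampling argument is accepted but unused by random_sampling):

import torch
from cosyvoice.utils.common import ras_sampling

logits = torch.randn(4096)        # unnormalized scores for one decode step
decoded = [812, 93, 93, 93, 93]   # hypothetical token history
# with win_size=10 and tau_r=0.1, a single repeat in the last 10 tokens
# already triggers the random-sampling fallback
next_id = ras_sampling(logits, decoded, sampling=None,
                       top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
print(int(next_id))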
cosyvoice/utils/executor.py
@@ -0,0 +1,176 @@
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from contextlib import nullcontext
+ import os
+
+ import torch
+ import torch.distributed as dist
+
+ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
+
+
+ class Executor:
+
+     def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
+         self.gan = gan
+         self.ref_model = ref_model
+         self.dpo_loss = dpo_loss
+         self.step = 0
+         self.epoch = 0
+         self.rank = int(os.environ.get('RANK', 0))
+         self.device = torch.device('cuda:{}'.format(self.rank))
+
+     def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
+         ''' Train one epoch
+         '''
+
+         lr = optimizer.param_groups[0]['lr']
+         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+         logging.info('using accumulate grad, new batch size is {} times'
+                      ' larger than before'.format(info_dict['accum_grad']))
+         # A context manager to be used in conjunction with an instance of
+         # torch.nn.parallel.DistributedDataParallel to be able to train
+         # with uneven inputs across participating processes.
+         model.train()
+         if self.ref_model is not None:
+             self.ref_model.eval()
+         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+         with model_context():
+             for batch_idx, batch_dict in enumerate(train_data_loader):
+                 info_dict["tag"] = "TRAIN"
+                 info_dict["step"] = self.step
+                 info_dict["epoch"] = self.epoch
+                 info_dict["batch_idx"] = batch_idx
+                 if cosyvoice_join(group_join, info_dict):
+                     break
+
+                 # Disable gradient synchronizations across DDP processes.
+                 # Within this context, gradients will be accumulated on module
+                 # variables, which will later be synchronized.
+                 if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                     context = model.no_sync
+                 # Used for single gpu training and DDP gradient synchronization
+                 # processes.
+                 else:
+                     context = nullcontext
+
+                 with context():
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
+                     info_dict = batch_backward(model, scaler, info_dict)
+
+                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                 log_per_step(writer, info_dict)
+                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                 if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                         (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     dist.barrier()
+                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                     model.train()
+                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     self.step += 1
+         dist.barrier()
+         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
+     def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                            writer, info_dict, scaler, group_join):
+         ''' Train one epoch
+         '''
+
+         lr = optimizer.param_groups[0]['lr']
+         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+         logging.info('using accumulate grad, new batch size is {} times'
+                      ' larger than before'.format(info_dict['accum_grad']))
+         # A context manager to be used in conjunction with an instance of
+         # torch.nn.parallel.DistributedDataParallel to be able to train
+         # with uneven inputs across participating processes.
+         model.train()
+         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+         with model_context():
+             for batch_idx, batch_dict in enumerate(train_data_loader):
+                 info_dict["tag"] = "TRAIN"
+                 info_dict["step"] = self.step
+                 info_dict["epoch"] = self.epoch
+                 info_dict["batch_idx"] = batch_idx
+                 if cosyvoice_join(group_join, info_dict):
+                     break
+
+                 # Disable gradient synchronizations across DDP processes.
+                 # Within this context, gradients will be accumulated on module
+                 # variables, which will later be synchronized.
+                 if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                     context = model.no_sync
+                 # Used for single gpu training and DDP gradient synchronization
+                 # processes.
+                 else:
+                     context = nullcontext
+
+                 with context():
+                     batch_dict['turn'] = 'discriminator'
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                     info_dict = batch_backward(model, scaler, info_dict)
+                 info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
+                 optimizer.zero_grad()
+                 log_per_step(writer, info_dict)
+                 with context():
+                     batch_dict['turn'] = 'generator'
+                     info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                     info_dict = batch_backward(model, scaler, info_dict)
+                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                 optimizer_d.zero_grad()
+                 log_per_step(writer, info_dict)
+                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                 if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                         (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     dist.barrier()
+                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                     model.train()
+                 if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                     self.step += 1
+         dist.barrier()
+         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
+     @torch.inference_mode()
+     def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
+         ''' Cross validation
+         '''
+         logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+         model.eval()
+         total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
+         for batch_idx, batch_dict in enumerate(cv_data_loader):
+             info_dict["tag"] = "CV"
+             info_dict["step"] = self.step
+             info_dict["epoch"] = self.epoch
+             info_dict["batch_idx"] = batch_idx
+
+             num_utts = len(batch_dict["utts"])
+             total_num_utts += num_utts
+
+             if self.gan is True:
+                 batch_dict['turn'] = 'generator'
+             info_dict = batch_forward(model, batch_dict, None, info_dict)
+
+             for k, v in info_dict['loss_dict'].items():
+                 if k not in total_loss_dict:
+                     total_loss_dict[k] = []
+                 total_loss_dict[k].append(v.item() * num_utts)
+             log_per_step(None, info_dict)
+         for k, v in total_loss_dict.items():
+             total_loss_dict[k] = sum(v) / total_num_utts
+         info_dict['loss_dict'] = total_loss_dict
+         log_per_save(writer, info_dict)
+         model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+         save_model(model, model_name, info_dict)
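The training loops above rely on one pattern worth isolating: gradients accumulate under model.no_sync, and only every accum_grad-th batch triggers a parameter update. A self-contained toy reproduction of that pattern (single process, so nullcontext stands in for model.no_sync, and update_parameter_and_lr is reduced to step/zero_grad):

from contextlib import nullcontext
import torch

accum_grad = 4
model = torch.nn.Linear(8, 8)    # stand-in module; DDP wrapping omitted
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for batch_idx in range(8):
    context = nullcontext        # would be model.no_sync on non-update DDP steps
    with context():
        loss = model(torch.randn(2, 8)).pow(2).mean() / accum_grad
        loss.backward()          # gradients accumulate across batches
    if (batch_idx + 1) % accum_grad == 0:
        opt.step()               # one optimizer step per accum_grad batches
        opt.zero_grad()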
cosyvoice/utils/file_utils.py
@@ -0,0 +1,129 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
+ #               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import json
+ import torch
+ import torchaudio
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.DEBUG,
+                     format='%(asctime)s %(levelname)s %(message)s')
+
+
+ def read_lists(list_file):
+     lists = []
+     with open(list_file, 'r', encoding='utf8') as fin:
+         for line in fin:
+             lists.append(line.strip())
+     return lists
+
+
+ def read_json_lists(list_file):
+     lists = read_lists(list_file)
+     results = {}
+     for fn in lists:
+         with open(fn, 'r', encoding='utf8') as fin:
+             results.update(json.load(fin))
+     return results
+
+
+ def load_wav(wav, target_sr):
+     speech, sample_rate = torchaudio.load(wav, backend='soundfile')
+     speech = speech.mean(dim=0, keepdim=True)
+     if sample_rate != target_sr:
+         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+     return speech
+
+
+ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
+     import tensorrt as trt
+     logging.info("Converting onnx to trt...")
+     network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+     logger = trt.Logger(trt.Logger.INFO)
+     builder = trt.Builder(logger)
+     network = builder.create_network(network_flags)
+     parser = trt.OnnxParser(network, logger)
+     config = builder.create_builder_config()
+     config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+     if fp16:
+         config.set_flag(trt.BuilderFlag.FP16)
+     profile = builder.create_optimization_profile()
+     # load onnx model
+     with open(onnx_model, "rb") as f:
+         if not parser.parse(f.read()):
+             for error in range(parser.num_errors):
+                 print(parser.get_error(error))
+             raise ValueError('failed to parse {}'.format(onnx_model))
+     # set input shapes
+     for i in range(len(trt_kwargs['input_names'])):
+         profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+     tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+     # set input and output data type
+     for i in range(network.num_inputs):
+         input_tensor = network.get_input(i)
+         input_tensor.dtype = tensor_dtype
+     for i in range(network.num_outputs):
+         output_tensor = network.get_output(i)
+         output_tensor.dtype = tensor_dtype
+     config.add_optimization_profile(profile)
+     engine_bytes = builder.build_serialized_network(network, config)
+     # save trt engine
+     with open(trt_model, "wb") as f:
+         f.write(engine_bytes)
+     logging.info("Successfully converted onnx to trt...")
+
+
+ def export_cosyvoice2_vllm(model, model_path, device):
+     if os.path.exists(model_path):
+         return
+     pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+     vocab_size = model.speech_embedding.num_embeddings
+     feature_size = model.speech_embedding.embedding_dim
+     pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+     dtype = torch.bfloat16
+     # lm_head
+     new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+     with torch.no_grad():
+         new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
+         new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
+         new_lm_head.weight[vocab_size:] = 0
+         new_lm_head.bias[vocab_size:] = 0
+     model.llm.model.lm_head = new_lm_head
+     new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+     # embed_tokens
+     embed_tokens = model.llm.model.model.embed_tokens
+     with torch.no_grad():
+         new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
+         new_codec_embed.weight[vocab_size:] = 0
+     model.llm.model.set_input_embeddings(new_codec_embed)
+     model.llm.model.to(device)
+     model.llm.model.to(dtype)
+     tmp_vocab_size = model.llm.model.config.vocab_size
+     tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
+     del model.llm.model.generation_config.eos_token_id
+     del model.llm.model.config.bos_token_id
+     del model.llm.model.config.eos_token_id
+     model.llm.model.config.vocab_size = pad_vocab_size
+     model.llm.model.config.tie_word_embeddings = False
+     model.llm.model.config.use_bias = True
+     model.llm.model.save_pretrained(model_path)
+     os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
+     model.llm.model.config.vocab_size = tmp_vocab_size
+     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+     model.llm.model.set_input_embeddings(embed_tokens)
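A short usage sketch for the I/O helpers above (file names and shape tuples are placeholders, not values shipped in this package): load_wav mixes down to mono and only ever resamples downward, and convert_onnx_to_trt expects parallel per-input lists in trt_kwargs.

from cosyvoice.utils.file_utils import load_wav, convert_onnx_to_trt

speech = load_wav('prompt.wav', target_sr=16000)   # -> (1, T) mono tensor
trt_kwargs = {
    'input_names': ['x'],
    'min_shape': [(2, 80, 4)],
    'opt_shape': [(2, 80, 200)],
    'max_shape': [(2, 80, 3000)],
}
convert_onnx_to_trt('estimator.plan', trt_kwargs, 'estimator.onnx', fp16=True)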