minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,386 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ from typing import Generator
+ import torch
+ import numpy as np
+ import threading
+ import time
+ from torch.nn import functional as F
+ from contextlib import nullcontext
+ import uuid
+ from cosyvoice.utils.common import fade_in_out
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+ from cosyvoice.utils.common import TrtContextWrapper
+
+
+ class CosyVoiceModel:
+
+     def __init__(self,
+                  llm: torch.nn.Module,
+                  flow: torch.nn.Module,
+                  hift: torch.nn.Module,
+                  fp16: bool = False):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.llm = llm
+         self.flow = flow
+         self.hift = hift
+         self.fp16 = fp16
+         if self.fp16 is True:
+             self.llm.half()
+             self.flow.half()
+         self.token_min_hop_len = 2 * self.flow.input_frame_rate
+         self.token_max_hop_len = 4 * self.flow.input_frame_rate
+         self.token_overlap_len = 20
+         # mel fade in out
+         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
+         self.mel_window = np.hamming(2 * self.mel_overlap_len)
+         # hift cache
+         self.mel_cache_len = 20
+         self.source_cache_len = int(self.mel_cache_len * 256)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+         # rtf and decoding related
+         self.stream_scale_factor = 1
+         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
+         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.lock = threading.Lock()
+         # dicts used to store session-related variables
+         self.tts_speech_token_dict = {}
+         self.llm_end_dict = {}
+         self.mel_overlap_dict = {}
+         self.flow_cache_dict = {}
+         self.hift_cache_dict = {}
+
+     def load(self, llm_model, flow_model, hift_model):
+         # self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+         # self.llm.to(self.device).eval()
+         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+         self.flow.to(self.device).eval()
+         # in case hift_model is a hifigan model
+         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+         self.hift.load_state_dict(hift_state_dict, strict=True)
+         self.hift.to(self.device).eval()
+
+     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+         # llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+         # self.llm.text_encoder = llm_text_encoder
+         # llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
+         # self.llm.llm = llm_llm
+         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+         self.flow.encoder = flow_encoder
+
+     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
+         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+         if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
+         del self.flow.decoder.estimator
+         import tensorrt as trt
+         with open(flow_decoder_estimator_model, 'rb') as f:
+             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+         self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+     def get_trt_kwargs(self):
+         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+         opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+         max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+         input_names = ["x", "mask", "mu", "cond"]
+         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
+             if isinstance(text, Generator):
+                 assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and does not support vllm!'
+                 for i in self.llm.inference_bistream(text=text,
+                                                      prompt_text=prompt_text.to(self.device),
+                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                      embedding=llm_embedding.to(self.device)):
+                     self.tts_speech_token_dict[uuid].append(i)
+             else:
+                 for i in self.llm.inference(text=text.to(self.device),
+                                             text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_text=prompt_text.to(self.device),
+                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=llm_embedding.to(self.device),
+                                             uuid=uuid):
+                     self.tts_speech_token_dict[uuid].append(i)
+         self.llm_end_dict[uuid] = True
+
+     def vc_job(self, source_speech_token, uuid):
+         self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
+         self.llm_end_dict[uuid] = True
+
+     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+         with torch.cuda.amp.autocast(self.fp16):
+             tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+                                                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                       prompt_token=prompt_token.to(self.device),
+                                                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                       prompt_feat=prompt_feat.to(self.device),
+                                                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                                       embedding=embedding.to(self.device),
+                                                                       flow_cache=self.flow_cache_dict[uuid])
+
+         # mel overlap fade in out
+         if self.mel_overlap_dict[uuid].shape[2] != 0:
+             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+         # keep overlap mel and hift cache
+         if finalize is False:
+             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+         else:
+             if speed != 1.0:
+                 assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
+                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+         return tts_speech
+
+     def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
+             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+             self.hift_cache_dict[this_uuid] = None
+             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+         if source_speech_token.shape[1] == 0:
+             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+         else:
+             p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
+         p.start()
+         if stream is True:
+             token_hop_len = self.token_min_hop_len
+             while True:
+                 time.sleep(0.1)
+                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                     this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                         .unsqueeze(dim=0)
+                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                      prompt_token=flow_prompt_speech_token,
+                                                      prompt_feat=prompt_speech_feat,
+                                                      embedding=flow_embedding,
+                                                      uuid=this_uuid,
+                                                      finalize=False)
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                     with self.lock:
+                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                     # increase token_hop_len for better speech quality
+                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                     break
+             p.join()
+             # deal with remaining tokens, make sure inference remaining token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             p.join()
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True,
+                                              speed=speed)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
+             self.mel_overlap_dict.pop(this_uuid)
+             self.hift_cache_dict.pop(this_uuid)
+             self.flow_cache_dict.pop(this_uuid)
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.current_stream().synchronize()
+
+
+ class CosyVoice2Model(CosyVoiceModel):
+
+     def __init__(self,
+                  llm: torch.nn.Module,
+                  flow: torch.nn.Module,
+                  hift: torch.nn.Module,
+                  fp16: bool = False):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # self.llm = llm
+         self.flow = flow
+         self.hift = hift
+         self.fp16 = fp16
+         if self.fp16 is True:
+             # self.llm.half()
+             self.flow.half()
+         # NOTE must match training static_chunk_size
+         self.token_hop_len = 25
+         # hift cache
+         self.mel_cache_len = 8
+         self.source_cache_len = int(self.mel_cache_len * 480)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+         # rtf and decoding related
+         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.lock = threading.Lock()
+         # dicts used to store session-related variables
+         self.tts_speech_token_dict = {}
+         self.llm_end_dict = {}
+         self.hift_cache_dict = {}
+
+     def load_jit(self, flow_encoder_model):
+         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+         self.flow.encoder = flow_encoder
+
+     def load_vllm(self, model_dir):
+         export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+         from vllm import EngineArgs, LLMEngine
+         engine_args = EngineArgs(model=model_dir,
+                                  skip_tokenizer_init=True,
+                                  enable_prompt_embeds=True,
+                                  gpu_memory_utilization=0.2)
+         self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+         self.llm.lock = threading.Lock()
+         del self.llm.llm.model.model.layers
+
+     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+         with torch.cuda.amp.autocast(self.fp16):
+             tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                              prompt_token=prompt_token.to(self.device),
+                                              prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                              prompt_feat=prompt_feat.to(self.device),
+                                              prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                              embedding=embedding.to(self.device),
+                                              streaming=stream,
+                                              finalize=finalize)
+         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+         # keep overlap mel and hift cache
+         if finalize is False:
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+         else:
+             if speed != 1.0:
+                 assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
+                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+         return tts_speech
+
+     def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
+             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+             self.hift_cache_dict[this_uuid] = None
+         if source_speech_token.shape[1] == 0:
+             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+         else:
+             p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
+         p.start()
+         if stream is True:
+             token_offset = 0
+             prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
+             while True:
+                 time.sleep(0.1)
+                 this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
+                 if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
+                     this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                      prompt_token=flow_prompt_speech_token,
+                                                      prompt_feat=prompt_speech_feat,
+                                                      embedding=flow_embedding,
+                                                      token_offset=token_offset,
+                                                      uuid=this_uuid,
+                                                      stream=stream,
+                                                      finalize=False)
+                     token_offset += this_token_hop_len
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
+                     break
+             p.join()
+             # deal with remaining tokens, make sure inference remaining token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              token_offset=token_offset,
+                                              uuid=this_uuid,
+                                              finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             p.join()
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              token_offset=0,
+                                              uuid=this_uuid,
+                                              finalize=True,
+                                              speed=speed)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
+             self.hift_cache_dict.pop(this_uuid)
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.current_stream().synchronize()
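
Note: both `tts` methods above are generators that yield dicts containing a `tts_speech` waveform tensor, with token generation running on a background thread (`llm_job` / `vc_job`). A minimal consumption sketch follows; the module objects, checkpoint paths, and input tensors (`my_llm`, `my_flow`, `my_hift`, `text_tokens`, `spk_embedding`, the `*.pt` names) are placeholders for illustration, not values shipped with this package.

# Hypothetical usage sketch for CosyVoiceModel.tts(); all names below are assumed, not package defaults.
import torch

model = CosyVoiceModel(llm=my_llm, flow=my_flow, hift=my_hift, fp16=False)  # pre-built torch modules (assumed)
model.load('llm.pt', 'flow.pt', 'hift.pt')                                  # hypothetical checkpoint paths

chunks = []
for out in model.tts(text=text_tokens,              # (1, T) int32 text token ids (assumed frontend output)
                     flow_embedding=spk_embedding,  # speaker embedding fed to the flow decoder
                     llm_embedding=spk_embedding,
                     stream=True):                  # stream=True yields chunks as tokens arrive
    chunks.append(out['tts_speech'])                # each chunk is a (1, N) waveform tensor on CPU
speech = torch.concat(chunks, dim=1)                # concatenate streamed chunks into the full waveform
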
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,151 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ import math
+ from functools import partial
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import IterableDataset
+ from cosyvoice.utils.file_utils import read_lists
+
+
+ class Processor(IterableDataset):
+
+     def __init__(self, source, f, *args, **kw):
+         assert callable(f)
+         self.source = source
+         self.f = f
+         self.args = args
+         self.kw = kw
+
+     def set_epoch(self, epoch):
+         self.source.set_epoch(epoch)
+
+     def __iter__(self):
+         """ Return an iterator over the source dataset processed by the
+             given processor.
+         """
+         assert self.source is not None
+         assert callable(self.f)
+         return self.f(iter(self.source), *self.args, **self.kw)
+
+     def apply(self, f):
+         assert callable(f)
+         return Processor(self, f, *self.args, **self.kw)
+
+
+ class DistributedSampler:
+
+     def __init__(self, shuffle=True, partition=True):
+         self.epoch = -1
+         self.update()
+         self.shuffle = shuffle
+         self.partition = partition
+
+     def update(self):
+         assert dist.is_available()
+         if dist.is_initialized():
+             self.rank = dist.get_rank()
+             self.world_size = dist.get_world_size()
+         else:
+             self.rank = 0
+             self.world_size = 1
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is None:
+             self.worker_id = 0
+             self.num_workers = 1
+         else:
+             self.worker_id = worker_info.id
+             self.num_workers = worker_info.num_workers
+         return dict(rank=self.rank,
+                     world_size=self.world_size,
+                     worker_id=self.worker_id,
+                     num_workers=self.num_workers)
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     def sample(self, data):
+         """ Sample data according to rank/world_size/num_workers
+
+         Args:
+             data(List): input data list
+
+         Returns:
+             List: data list after sampling
+         """
+         data = list(range(len(data)))
+         # force datalist even
+         if self.partition:
+             if self.shuffle:
+                 random.Random(self.epoch).shuffle(data)
+             if len(data) < self.world_size:
+                 data = data * math.ceil(self.world_size / len(data))
+                 data = data[:self.world_size]
+             data = data[self.rank::self.world_size]
+         if len(data) < self.num_workers:
+             data = data * math.ceil(self.num_workers / len(data))
+             data = data[:self.num_workers]
+         data = data[self.worker_id::self.num_workers]
+         return data
+
+
+ class DataList(IterableDataset):
+
+     def __init__(self, lists, shuffle=True, partition=True):
+         self.lists = lists
+         self.sampler = DistributedSampler(shuffle, partition)
+
+     def set_epoch(self, epoch):
+         self.sampler.set_epoch(epoch)
+
+     def __iter__(self):
+         sampler_info = self.sampler.update()
+         indexes = self.sampler.sample(self.lists)
+         for index in indexes:
+             data = dict(src=self.lists[index])
+             data.update(sampler_info)
+             yield data
+
+
+ def Dataset(data_list_file,
+             data_pipeline,
+             mode='train',
+             gan=False,
+             dpo=False,
+             shuffle=True,
+             partition=True):
+     """ Construct dataset from arguments
+
+     We have two shuffle stages in the Dataset. The first is global
+     shuffle at shards tar/raw file level. The second is global shuffle
+     at training samples level.
+
+     Args:
+         data_type(str): raw/shard
+         tokenizer (BaseTokenizer): tokenizer to tokenize
+         partition(bool): whether to do data partition in terms of rank
+     """
+     lists = read_lists(data_list_file)
+     dataset = DataList(lists,
+                        shuffle=shuffle,
+                        partition=partition)
+     # map partial arg to padding func
+     data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
+     for func in data_pipeline:
+         dataset = Processor(dataset, func, mode=mode)
+     return dataset
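
For reference, `Dataset` above wraps the entries of `data_list_file` in a `DataList` and then chains one `Processor` per pipeline stage, each lazily transforming the iterator produced by the previous stage. A minimal driving sketch follows; the list file name and the two stage functions are placeholders for illustration (the package's real stages are expected to come from cosyvoice/dataset/processor.py), not part of this module.

# Hypothetical usage sketch for Dataset(); 'train.data.list' and both stage functions are placeholders.
from torch.utils.data import DataLoader

def opener(samples, mode='train'):
    # each sample is a dict like {'src': <line from the data list>, 'rank': ..., 'worker_id': ...}
    for sample in samples:
        yield sample

def padding(samples, gan=False, dpo=False, mode='train'):
    # final stage; Dataset() binds gan/dpo onto it via functools.partial
    for sample in samples:
        yield sample

dataset = Dataset('train.data.list', data_pipeline=[opener, padding],
                  mode='train', gan=False, dpo=False, shuffle=True, partition=True)
dataset.set_epoch(0)  # forwarded down the Processor chain to DistributedSampler.set_epoch
loader = DataLoader(dataset, batch_size=None, num_workers=2)
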