minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
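Together these entries cover four bundled speech stacks (CosyVoice, Matcha-TTS, S3Tokenizer, Step-Audio2) plus the minicpmo glue package, matching the five names recorded in minicpmo_utils-0.1.0.dist-info/top_level.txt. As a quick post-install sanity check, each top-level package can simply be imported; the sketch below relies only on the package names visible in this listing, not on any particular API inside them.

# Sanity-check sketch: import each top-level package shipped in the wheel.
# Assumes the wheel has been installed, e.g.
#   pip install minicpmo_utils-0.1.0-py3-none-any.whl
import importlib

for pkg in ("cosyvoice", "matcha", "minicpmo", "s3tokenizer", "stepaudio2"):
    module = importlib.import_module(pkg)
    print(f"{pkg}: {module.__file__}")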
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,610 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import queue
16
+ import random
17
+ import time
18
+ import threading
19
+ from typing import Dict, Optional, Callable, List, Generator
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+ from transformers import Qwen2ForCausalLM
24
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
25
+ from cosyvoice.utils.common import IGNORE_ID
26
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
27
+ from cosyvoice.utils.common import th_accuracy
28
+ from cosyvoice.utils.file_utils import logging
29
+ from cosyvoice.utils.mask import make_pad_mask
30
+
31
+
32
+ class TransformerLM(torch.nn.Module):
33
+ def __init__(
34
+ self,
35
+ text_encoder_input_size: int,
36
+ llm_input_size: int,
37
+ llm_output_size: int,
38
+ text_token_size: int,
39
+ speech_token_size: int,
40
+ text_encoder: torch.nn.Module,
41
+ llm: torch.nn.Module,
42
+ sampling: Callable,
43
+ length_normalized_loss: bool = True,
44
+ lsm_weight: float = 0.0,
45
+ spk_embed_dim: int = 192,
46
+ ):
47
+ super().__init__()
48
+ self.llm_input_size = llm_input_size
49
+ self.speech_token_size = speech_token_size
50
+ # 1. build text token inputs related modules
51
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
52
+ self.text_encoder = text_encoder
53
+ self.text_encoder_affine_layer = nn.Linear(
54
+ self.text_encoder.output_size(),
55
+ llm_input_size
56
+ )
57
+
58
+ # 2. build speech token language model related modules
59
+ self.sos_eos = 0
60
+ self.task_id = 1
61
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
62
+ self.llm = llm
63
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
64
+ self.criterion_ce = LabelSmoothingLoss(
65
+ size=speech_token_size + 1,
66
+ padding_idx=IGNORE_ID,
67
+ smoothing=lsm_weight,
68
+ normalize_length=length_normalized_loss,
69
+ )
70
+
71
+ # 3. [Optional] build speech token related modules
72
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
73
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
74
+
75
+ # 4. sampling method
76
+ self.sampling = sampling
77
+
78
+ def encode(
79
+ self,
80
+ text: torch.Tensor,
81
+ text_lengths: torch.Tensor,
82
+ ):
83
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
84
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
85
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
86
+ return encoder_out, encoder_out_lens
87
+
88
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
89
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
90
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
91
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
92
+ for i in range(len(text_token))]
93
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
94
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
95
+ return lm_input, lm_input_len
96
+
97
+ def forward(
98
+ self,
99
+ batch: dict,
100
+ device: torch.device,
101
+ ) -> Dict[str, Optional[torch.Tensor]]:
102
+ """
103
+ Args:
104
+ text: (B, L, D)
105
+ text_lengths: (B,)
106
+ audio: (B, T, N) or (B, T)
107
+ audio_lengths: (B,)
108
+ """
109
+ text_token = batch['text_token'].to(device)
110
+ text_token_len = batch['text_token_len'].to(device)
111
+ speech_token = batch['speech_token'].to(device)
112
+ speech_token_len = batch['speech_token_len'].to(device)
113
+ embedding = batch['embedding'].to(device)
114
+
115
+ # 1. prepare llm_target
116
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
117
+ [self.speech_token_size]) for i in range(text_token.size(0))]
118
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
119
+
120
+ # 1. encode text_token
121
+ text_token = self.text_embedding(text_token)
122
+ text_token, text_token_len = self.encode(text_token, text_token_len)
123
+
124
+ # 2. embedding projection
125
+ embedding = F.normalize(embedding, dim=1)
126
+ embedding = self.spk_embed_affine_layer(embedding)
127
+ embedding = embedding.unsqueeze(1)
128
+
129
+ # 3. eos and task_id
130
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
131
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
132
+
133
+ # 4. encode speech_token
134
+ speech_token = self.speech_embedding(speech_token)
135
+
136
+ # 5. unpad and pad
137
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
138
+ task_id_emb, speech_token, speech_token_len)
139
+
140
+ # 6. run lm forward
141
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
142
+ logits = self.llm_decoder(lm_output)
143
+ loss = self.criterion_ce(logits, lm_target)
144
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
145
+ return {'loss': loss, 'acc': acc}
146
+
147
+ def sampling_ids(
148
+ self,
149
+ weighted_scores: torch.Tensor,
150
+ decoded_tokens: List,
151
+ sampling: int,
152
+ ignore_eos: bool = True,
153
+ ):
154
+ num_trials, max_trials = 0, 100
155
+ while True:
156
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
157
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
158
+ break
159
+ num_trials += 1
160
+ if num_trials > max_trials:
161
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
162
+ return top_ids
163
+
164
+ @torch.inference_mode()
165
+ def inference(
166
+ self,
167
+ text: torch.Tensor,
168
+ text_len: torch.Tensor,
169
+ prompt_text: torch.Tensor,
170
+ prompt_text_len: torch.Tensor,
171
+ prompt_speech_token: torch.Tensor,
172
+ prompt_speech_token_len: torch.Tensor,
173
+ embedding: torch.Tensor,
174
+ sampling: int = 25,
175
+ max_token_text_ratio: float = 20,
176
+ min_token_text_ratio: float = 2,
177
+ uuid: str = '',
178
+ ) -> Generator[torch.Tensor, None, None]:
179
+ device = text.device
180
+ text = torch.concat([prompt_text, text], dim=1)
181
+ text_len += prompt_text_len
182
+ text = self.text_embedding(text)
183
+
184
+ # 1. encode text
185
+ text, text_len = self.encode(text, text_len)
186
+
187
+ # 2. encode embedding
188
+ if embedding.shape[0] != 0:
189
+ embedding = F.normalize(embedding, dim=1)
190
+ embedding = self.spk_embed_affine_layer(embedding)
191
+ embedding = embedding.unsqueeze(dim=1)
192
+ else:
193
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
194
+
195
+ # 3. concat llm_input
196
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
197
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
198
+ if prompt_speech_token_len != 0:
199
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
200
+ else:
201
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
202
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
203
+
204
+ # 4. cal min/max_length
205
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
206
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
207
+
208
+ # 5. step by step decode
209
+ out_tokens = []
210
+ offset = 0
211
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
212
+ for i in range(max_len):
213
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
214
+ att_cache=att_cache, cnn_cache=cnn_cache,
215
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
216
+ device=lm_input.device)).to(torch.bool))
217
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
218
+ # force continue decode first token
219
+ if i == 0:
220
+ logp[:, self.speech_token_size] = -float('inf')
221
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
222
+ if top_ids == self.speech_token_size:
223
+ break
224
+ # in stream mode, yield token one by one
225
+ yield top_ids
226
+ out_tokens.append(top_ids)
227
+ offset += lm_input.size(1)
228
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
229
+
230
+
231
+ class Qwen2Encoder(torch.nn.Module):
232
+ def __init__(self, pretrain_path):
233
+ super().__init__()
234
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
235
+
236
+ def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
237
+ T = xs.size(1)
238
+ masks = ~make_pad_mask(xs_lens, T)
239
+ outs = self.model(
240
+ inputs_embeds=xs,
241
+ attention_mask=masks,
242
+ output_hidden_states=True,
243
+ return_dict=True,
244
+ )
245
+ return outs.hidden_states[-1], masks.unsqueeze(1)
246
+
247
+ def forward_one_step(self, xs, masks, cache=None):
248
+ input_masks = masks[:, -1, :]
249
+ outs = self.model(
250
+ inputs_embeds=xs,
251
+ attention_mask=input_masks,
252
+ output_hidden_states=True,
253
+ return_dict=True,
254
+ use_cache=True,
255
+ past_key_values=cache,
256
+ )
257
+ xs = outs.hidden_states[-1]
258
+ new_cache = outs.past_key_values
259
+ return xs, new_cache
260
+
261
+
262
+ class Qwen2LM(TransformerLM):
263
+ def __init__(
264
+ self,
265
+ llm_input_size: int,
266
+ llm_output_size: int,
267
+ speech_token_size: int,
268
+ llm: torch.nn.Module,
269
+ sampling: Callable,
270
+ length_normalized_loss: bool = True,
271
+ lsm_weight: float = 0.0,
272
+ mix_ratio: List[int] = [5, 15],
273
+ ):
274
+ torch.nn.Module.__init__(self)
275
+ self.llm_input_size = llm_input_size
276
+ self.llm_output_size = llm_output_size
277
+ self.speech_token_size = speech_token_size
278
+ # 2. build speech token language model related modules
279
+ self.sos_eos = 0
280
+ self.task_id = 1
281
+ self.fill_token = 2
282
+
283
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
284
+ self.llm = llm
285
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
286
+ self.criterion_ce = LabelSmoothingLoss(
287
+ size=speech_token_size + 3,
288
+ padding_idx=IGNORE_ID,
289
+ smoothing=lsm_weight,
290
+ normalize_length=length_normalized_loss,
291
+ )
292
+
293
+ # 3. [Optional] build speech token related modules
294
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
295
+
296
+ # 4. sampling method
297
+ self.sampling = sampling
298
+ self.mix_ratio = mix_ratio
299
+
300
+ # 5. vllm related
301
+ self.stop_token_ids = [speech_token_size + i for i in range(3)]
302
+ self.vllm_output_queue = {}
303
+
304
+ def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
305
+ lm_target, lm_input = [], []
306
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
307
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
308
+ text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
309
+ speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
310
+ for i in range(len(text_token)):
311
+ # bistream sequence
312
+ if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
313
+ this_lm_target, this_lm_input = [], []
314
+ this_lm_target.append(IGNORE_ID)
315
+ this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
316
+ for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
317
+ this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
318
+ this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
319
+ if len(this_text_token) == self.mix_ratio[0]:
320
+ assert len(this_speech_token) == self.mix_ratio[1]
321
+ this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
322
+ this_lm_target += this_speech_token
323
+ this_lm_target.append(self.speech_token_size + 2)
324
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
325
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
326
+ else:
327
+ this_lm_target += [-1] * len(this_text_token)
328
+ this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
329
+ this_lm_target.append(self.speech_token_size)
330
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
331
+ this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
332
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
333
+ this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
334
+ # unistream sequence
335
+ else:
336
+ this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
337
+ this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
338
+ self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
339
+ lm_target.append(this_lm_target)
340
+ lm_input.append(this_lm_input)
341
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
342
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
343
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
344
+ return lm_target, lm_input, lm_input_len
345
+
346
+ def forward(
347
+ self,
348
+ batch: dict,
349
+ device: torch.device,
350
+ ) -> Dict[str, Optional[torch.Tensor]]:
351
+ """
352
+ Args:
353
+ text: (B, L, D)
354
+ text_lengths: (B,)
355
+ audio: (B, T, N) or (B, T)
356
+ audio_lengths: (B,)
357
+ """
358
+ text_token = batch['text_token'].to(device)
359
+ text_token_len = batch['text_token_len'].to(device)
360
+ speech_token = batch['speech_token'].to(device)
361
+ speech_token_len = batch['speech_token_len'].to(device)
362
+
363
+ # 1. encode text_token
364
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
365
+
366
+ # 2. encode speech_token
367
+ speech_token_emb = self.speech_embedding(speech_token)
368
+
369
+ # 3. prepare llm_input/target
370
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
371
+ lm_target = lm_target.to(device)
372
+
373
+ # 4. run lm forward
374
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
375
+ logits = self.llm_decoder(lm_output)
376
+ loss = self.criterion_ce(logits, lm_target.to(device))
377
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
378
+ return {'loss': loss, 'acc': acc}
379
+
380
+ def forward_dpo(
381
+ self,
382
+ batch: dict,
383
+ device: torch.device,
384
+ ) -> Dict[str, Optional[torch.Tensor]]:
385
+ text_token = batch['text_token'].to(device)
386
+ text_token_len = batch['text_token_len'].to(device)
387
+ speech_token = batch['speech_token'].to(device)
388
+ speech_token_len = batch['speech_token_len'].to(device)
389
+ reject_speech_token = batch['reject_speech_token'].to(device)
390
+ reject_speech_token_len = batch['reject_speech_token_len'].to(device)
391
+
392
+ # 1. encode text_token
393
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
394
+
395
+ # 2. encode speech_token
396
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
397
+ reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
398
+ speech_token_combined = speech_token + reject_speech_token
399
+ speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
400
+ speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
401
+ speech_token_combined_emb = self.speech_embedding(speech_token_combined)
402
+
403
+ # 3. prepare llm_input/target
404
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
405
+ lm_target = lm_target.to(device)
406
+
407
+ # 4. run lm forward
408
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
409
+ logits = self.llm_decoder(lm_output)
410
+ chosen_logits = logits[:text_token.shape[0]]
411
+ rejected_logits = logits[text_token.shape[0]:]
412
+ chosen_lm_target = lm_target[:text_token.shape[0]]
413
+ rejected_lm_target = lm_target[text_token.shape[0]:]
414
+ loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
415
+ acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
416
+
417
+ # 5. calculate dpo logits
418
+ chosen_lm_mask = chosen_lm_target == IGNORE_ID
419
+ rejected_lm_mask = rejected_lm_target == IGNORE_ID
420
+ chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
421
+ rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
422
+ chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
423
+ rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
424
+ return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
425
+
426
+ @torch.inference_mode()
427
+ def inference(
428
+ self,
429
+ text: torch.Tensor,
430
+ text_len: torch.Tensor,
431
+ prompt_text: torch.Tensor,
432
+ prompt_text_len: torch.Tensor,
433
+ prompt_speech_token: torch.Tensor,
434
+ prompt_speech_token_len: torch.Tensor,
435
+ embedding: torch.Tensor,
436
+ sampling: int = 25,
437
+ max_token_text_ratio: float = 20,
438
+ min_token_text_ratio: float = 2,
439
+ uuid: str = '',
440
+ ) -> Generator[torch.Tensor, None, None]:
441
+ device = text.device
442
+ text = torch.concat([prompt_text, text], dim=1)
443
+ text_len += prompt_text_len
444
+ text = self.llm.model.model.embed_tokens(text)
445
+
446
+ # 3. concat llm_input
447
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
448
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
449
+ if prompt_speech_token_len != 0:
450
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
451
+ else:
452
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
453
+ lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
454
+
455
+ # 4. cal min/max_length
456
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
457
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
458
+
459
+ # 5. step by step decode
460
+ for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
461
+ yield token
462
+
463
+ @torch.inference_mode()
464
+ def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
465
+ if hasattr(self, 'vllm'):
466
+ from vllm import SamplingParams, RequestOutput
467
+ sampling_params = SamplingParams(top_k=sampling,
468
+ stop_token_ids=self.stop_token_ids,
469
+ min_tokens=min_len,
470
+ max_tokens=max_len)
471
+ with self.lock:
472
+ self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
473
+ self.vllm_output_queue[uuid] = queue.Queue()
474
+ out_tokens = []
475
+ while True:
476
+ with self.lock:
477
+ if self.vllm_output_queue[uuid].empty() is True:
478
+ request_outputs: List[RequestOutput] = self.vllm.step()
479
+ for request_output in request_outputs:
480
+ top_ids = list(request_output.outputs[0].token_ids)[-1]
481
+ self.vllm_output_queue[request_output.request_id].put(top_ids)
482
+ if self.vllm_output_queue[uuid].empty() is False:
483
+ top_ids = self.vllm_output_queue[uuid].get()
484
+ if top_ids in self.stop_token_ids:
485
+ break
486
+ # in stream mode, yield token one by one
487
+ yield top_ids
488
+ out_tokens.append(top_ids)
489
+ if len(out_tokens) == max_len:
490
+ break
491
+ time.sleep(0.001)
492
+ with self.lock:
493
+ self.vllm_output_queue.pop(uuid)
494
+ else:
495
+ out_tokens = []
496
+ cache = None
497
+ for i in range(max_len):
498
+ y_pred, cache = self.llm.forward_one_step(lm_input,
499
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
500
+ cache=cache)
501
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
502
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
503
+ if top_ids == self.speech_token_size:
504
+ break
505
+ if top_ids > self.speech_token_size:
506
+ continue
507
+ # in stream mode, yield token one by one
508
+ yield top_ids
509
+ out_tokens.append(top_ids)
510
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
511
+
512
+ @torch.inference_mode()
513
+ def inference_bistream(
514
+ self,
515
+ text: Generator,
516
+ prompt_text: torch.Tensor,
517
+ prompt_text_len: torch.Tensor,
518
+ prompt_speech_token: torch.Tensor,
519
+ prompt_speech_token_len: torch.Tensor,
520
+ embedding: torch.Tensor,
521
+ sampling: int = 25,
522
+ max_token_text_ratio: float = 20,
523
+ min_token_text_ratio: float = 2,
524
+ ) -> Generator[torch.Tensor, None, None]:
525
+
526
+ device = prompt_text.device
527
+ # 1. prepare input
528
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
529
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
530
+ if prompt_speech_token_len != 0:
531
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
532
+ else:
533
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
534
+ lm_input = torch.concat([sos_eos_emb], dim=1)
535
+
536
+ # 2. iterate text
537
+ out_tokens = []
538
+ cache = None
539
+ # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
540
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
541
+ next_fill_index = -1
542
+ for this_text in text:
543
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
544
+ # prompt_speech_token_emb not empty, try append to lm_input
545
+ while prompt_speech_token_emb.size(1) != 0:
546
+ if text_cache.size(1) >= self.mix_ratio[0]:
547
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
548
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
549
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
550
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
551
+ else:
552
+ logging.info('not enough text token to decode, wait for more')
553
+ break
554
+ # no prompt_speech_token_emb remain, can decode some speech token
555
+ if prompt_speech_token_emb.size(1) == 0:
556
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
557
+ logging.info('get fill token, need to append more text token')
558
+ if text_cache.size(1) >= self.mix_ratio[0]:
559
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
560
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
561
+ if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
562
+ lm_input = lm_input_text
563
+ else:
564
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
565
+ text_cache = text_cache[:, self.mix_ratio[0]:]
566
+ else:
567
+ logging.info('not enough text token to decode, wait for more')
568
+ continue
569
+ while True:
570
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
571
+ y_pred, cache = self.llm.forward_one_step(lm_input,
572
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
573
+ cache=cache)
574
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
575
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
576
+ top_ids = self.speech_token_size + 2
577
+ next_fill_index += (self.mix_ratio[1] + 1)
578
+ else:
579
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
580
+ if top_ids == self.speech_token_size + 2:
581
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
582
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
583
+ out_tokens.append(top_ids)
584
+ if top_ids >= self.speech_token_size:
585
+ if top_ids == self.speech_token_size + 2:
586
+ break
587
+ else:
588
+ raise ValueError('should not get token {}'.format(top_ids))
589
+ yield top_ids
590
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
591
+
592
+ # 3. final decode
593
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
594
+ logging.info('no more text token, decode until met eos')
595
+ while True:
596
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
597
+ y_pred, cache = self.llm.forward_one_step(lm_input,
598
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
599
+ cache=cache)
600
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
601
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
602
+ out_tokens.append(top_ids)
603
+ if top_ids >= self.speech_token_size:
604
+ if top_ids == self.speech_token_size:
605
+ break
606
+ else:
607
+ raise ValueError('should not get token {}'.format(top_ids))
608
+ # in stream mode, yield token one by one
609
+ yield top_ids
610
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
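For context on how the added module is consumed: Qwen2LM.inference is a generator that embeds the concatenated prompt and input text with the Qwen2 backbone, then autoregressively samples speech token ids and yields them one at a time until the end-of-speech id (speech_token_size) is drawn, so the downstream flow-matching and HiFi-GAN stages can start synthesizing before decoding finishes. The sketch below is illustrative only: the constructor values (hidden size 896 for a Qwen2-0.5B backbone, 6561 speech tokens), the Hugging Face checkpoint id, and the greedy stand-in for the sampling callable are assumptions, not values taken from this diff; the released checkpoints wire these through their own configs, and real prompt tensors come from the CosyVoice frontend.

import torch
from cosyvoice.llm.llm import Qwen2Encoder, Qwen2LM

# Stand-in sampler: Qwen2LM calls sampling(weighted_scores, decoded_tokens, sampling)
# and expects a tensor holding the chosen id; greedy argmax keeps the sketch simple.
def greedy_sampling(weighted_scores, decoded_tokens, sampling):
    return weighted_scores.argmax(dim=-1, keepdim=True)

llm = Qwen2Encoder("Qwen/Qwen2-0.5B")   # assumption: any Qwen2 checkpoint with hidden size 896
lm = Qwen2LM(llm_input_size=896, llm_output_size=896,
             speech_token_size=6561, llm=llm, sampling=greedy_sampling).eval()

text = torch.randint(0, 1000, (1, 32))                       # (B, L) text token ids from the tokenizer
text_len = torch.tensor([text.size(1)], dtype=torch.int32)
empty = torch.zeros(1, 0, dtype=torch.long)                   # no prompt text / prompt speech tokens
zero_len = torch.tensor([0], dtype=torch.int32)

speech_tokens = []
for token_id in lm.inference(text, text_len, empty, zero_len, empty, zero_len,
                             embedding=torch.zeros(1, 0)):    # embedding is unused by Qwen2LM.inference
    speech_tokens.append(token_id)                            # each id can be streamed to the token2wav stage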