xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
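Judging by the line counts above, the biggest functional areas touched are the rerank stack (new sentence_transformers and vllm backends plus a rerank_family registry), the LLM side (new harmony.py and gpt_oss.py, large llm_family.json additions), and the bundled CosyVoice third-party code whose diffs follow below. As a quick orientation for the rerank refactor, here is a minimal, hypothetical sketch of exercising a rerank model through the RESTful client against a locally running 1.9.1 server; the endpoint, model name, and result layout are assumptions to be checked against the new rerank model_spec.json, not something this diff guarantees.

    from xinference.client import Client

    # Hypothetical smoke test for the reorganized rerank stack; endpoint and
    # model name below are assumptions, not taken from this diff.
    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="bge-reranker-v2-m3", model_type="rerank")
    model = client.get_model(model_uid)
    result = model.rerank(
        documents=["xinference serves LLMs.", "CosyVoice is a TTS model."],
        query="Which entry is about text to speech?",
        top_n=1,
    )
    print(result["results"][0])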
xinference/thirdparty/cosyvoice/flow/flow_matching.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,10 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
+from cosyvoice.utils.common import set_all_random_seed
 
 
 class ConditionalCFM(BASECFM):
@@ -31,7 +32,6 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
-        self.lock = threading.Lock()
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -68,7 +68,7 @@ class ConditionalCFM(BASECFM):
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
 
-    def solve_euler(self, x, t_span, mu, mask, spks, cond):
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
         """
         Fixed euler solver for ODEs.
         Args:
@@ -109,7 +109,8 @@ class ConditionalCFM(BASECFM):
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
-                cond_in
+                cond_in,
+                streaming
             )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -121,25 +122,33 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1].float()
 
-    def forward_estimator(self, x, mask, mu, t, spks, cond):
+    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond)
+            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
-
-
-
-
-
-
-
+            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            torch.cuda.current_stream().synchronize()
+            with stream:
+                estimator.set_input_shape('x', (2, 80, x.size(2)))
+                estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                estimator.set_input_shape('t', (2,))
+                estimator.set_input_shape('spks', (2, 80))
+                estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                data_ptrs = [x.contiguous().data_ptr(),
+                             mask.contiguous().data_ptr(),
+                             mu.contiguous().data_ptr(),
+                             t.contiguous().data_ptr(),
+                             spks.contiguous().data_ptr(),
+                             cond.contiguous().data_ptr(),
+                             x.data_ptr()]
+                for i, j in enumerate(data_ptrs):
+                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
                 # run trt engine
-                assert
-
-
-                t.contiguous().data_ptr(),
-                spks.contiguous().data_ptr(),
-                cond.contiguous().data_ptr(),
-                x.data_ptr()]) is True
+                assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+            torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator, stream)
         return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -187,10 +196,11 @@ class ConditionalCFM(BASECFM):
 class CausalConditionalCFM(ConditionalCFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        set_all_random_seed(0)
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None,
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
         """Forward diffusion
 
         Args:
@@ -209,131 +219,9 @@ class CausalConditionalCFM(ConditionalCFM):
             shape: (batch_size, n_feats, mel_timesteps)
         """
 
-
-        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
-        z = z[:, :, offset:]
-        offset += mu.size(2)
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-
-        cache['offset'] = offset
-        return mel, cache
-
-    def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
-        """
-        Fixed euler solver for ODEs.
-        Args:
-            x (torch.Tensor): random noise
-            t_span (torch.Tensor): n_timesteps interpolated
-                shape: (n_timesteps + 1,)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): output_mask
-                shape: (batch_size, 1, mel_timesteps)
-            spks (torch.Tensor, optional): speaker ids. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-            cond: Not used but kept for future purposes
-        """
-        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
-        t = t.unsqueeze(dim=0)
-
-        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
-        # Or in future might add like a return_all_steps flag
-        sol = []
-
-        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
-        for step in range(1, len(t_span)):
-            # Classifier-Free Guidance inference introduced in VoiceBox
-            x_in[:] = x
-            mask_in[:] = mask
-            mu_in[0] = mu
-            t_in[:] = t.unsqueeze(0)
-            spks_in[0] = spks
-            cond_in[0] = cond
-            cache_step = {k: v[step - 1] for k, v in cache.items()}
-            dphi_dt, cache_step = self.forward_estimator(
-                x_in, mask_in,
-                mu_in, t_in,
-                spks_in,
-                cond_in,
-                cache_step
-            )
-            # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
-            if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
-                cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
-                cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
-                cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
-                cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
-            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
-            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
-            x = x + dt * dphi_dt
-            t = t + dt
-            sol.append(x)
-            if step < len(t_span) - 1:
-                dt = t_span[step + 1] - t
-        return sol[-1].float(), cache
-
-    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
-        if isinstance(self.estimator, torch.nn.Module):
-            x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
-            cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
-        else:
-            with self.lock:
-                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
-                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('t', (2,))
-                self.estimator.set_input_shape('spks', (2, 80))
-                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
-                # run trt engine
-                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
-                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                                  mask.contiguous().data_ptr(),
-                                                  mu.contiguous().data_ptr(),
-                                                  t.contiguous().data_ptr(),
-                                                  spks.contiguous().data_ptr(),
-                                                  cond.contiguous().data_ptr(),
-                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  x.data_ptr(),
-                                                  cache['down_blocks_conv_cache'].data_ptr(),
-                                                  down_blocks_kv_cache_out.data_ptr(),
-                                                  cache['mid_blocks_conv_cache'].data_ptr(),
-                                                  mid_blocks_kv_cache_out.data_ptr(),
-                                                  cache['up_blocks_conv_cache'].data_ptr(),
-                                                  up_blocks_kv_cache_out.data_ptr(),
-                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
-                cache = (cache['down_blocks_conv_cache'],
-                         down_blocks_kv_cache_out,
-                         cache['mid_blocks_conv_cache'],
-                         mid_blocks_kv_cache_out,
-                         cache['up_blocks_conv_cache'],
-                         up_blocks_kv_cache_out,
-                         cache['final_blocks_conv_cache'])
-        return x, cache
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
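The net effect of this file's changes: the per-chunk offset/KV-cache plumbing is dropped from CausalConditionalCFM, the noise buffer is made deterministic via set_all_random_seed(0), and streaming is now expressed as a plain streaming flag threaded from forward() through solve_euler() into the estimator call (with the TensorRT path acquiring a per-call engine/stream pair instead of sharing one lock). Below is a simplified, self-contained sketch of that flag threading; ToyEstimator is a stand-in, not the real CosyVoice decoder.

    import torch

    class ToyEstimator(torch.nn.Module):
        def forward(self, x, mask, mu, t, spks, cond, streaming=False):
            # a real estimator would switch to chunk-causal attention when streaming=True
            return mu - x

    def solve_euler(x, t_span, mu, mask, spks, cond, estimator, streaming=False):
        # same fixed-step Euler loop shape as the diff above, minus CFG and caching
        t, dt = t_span[0], t_span[1] - t_span[0]
        for step in range(1, len(t_span)):
            dphi_dt = estimator(x, mask, mu, t, spks, cond, streaming=streaming)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return x

    estimator = ToyEstimator()
    x, mu = torch.randn(1, 80, 120), torch.randn(1, 80, 120)
    mask = torch.ones(1, 1, 120)
    t_span = torch.linspace(0, 1, 11)
    mel = solve_euler(x, t_span, mu, mask, spks=None, cond=None, estimator=estimator, streaming=True)
    print(mel.shape)  # torch.Size([1, 80, 120])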
xinference/thirdparty/cosyvoice/hifigan/generator.py

@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
         return sine_merge, noise, uv
 
 
+class SineGen2(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine-wavefrom (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_thoreshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SinGen is used inside PulseGen (default False)
+    Note: when flag_for_pulse is True, the first time step of a voiced
+        segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen2, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+        self.upsample_scale = upsample_scale
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = (f0 > self.voiced_threshold).type(torch.float32)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        # convert to F0 in rad. The interger part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+        if not self.flag_for_pulse:
+            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+                                                          scale_factor=1 / self.upsample_scale,
+                                                          mode="linear").transpose(1, 2)
+
+            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+            sines = torch.sin(phase)
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segments is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantanouse phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segments
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        # fundamental component
+        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+        # generate sine waveforms
+        sine_waves = self._f02sine(fn) * self.sine_amp
+
+        # generate uv signal
+        uv = self._f02uv(f0)
+
+        # noise: for unvoiced should be similar to sine_amp
+        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+        # .      for voiced regions is self.noise_std
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+
+        # first: set the unvoiced part to 0 by uv
+        # then: additive noise
+        sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF2(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonic above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threhold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF2, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
+                                  sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length 1)
+        """
+        # source for harmonic branch
+        with torch.no_grad():
+            sine_wavs, uv, _ = self.l_sin_gen(x)
+            sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+
+
 class HiFTGenerator(nn.Module):
     """
     HiFTNet Generator: Neural Source Filter + ISTFTNet
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):
 
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.m_source = SourceModuleHnNSF(
+        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
+        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
+        self.m_source = this_SourceModuleHnNSF(
             sampling_rate=sampling_rate,
             upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
             harmonic_num=nb_harmonics,
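HiFTGenerator now picks its neural-source-filter module by sampling rate: the original SourceModuleHnNSF for the 22.05 kHz CosyVoice vocoder, and the new SineGen2/SourceModuleHnNSF2 pair otherwise. The core of SineGen2 is harmonic phase accumulation from F0; below is a tiny self-contained illustration of that idea, leaving out the interpolation, noise, and voiced/unvoiced handling shown in the diff above (the 24 kHz rate and helper name are illustrative only).

    import torch

    def toy_f0_to_sines(f0, sampling_rate=24000, harmonic_num=7):
        # stack fundamental + harmonics along the last dim: (batch, length, harmonic_num + 1)
        ratios = torch.arange(1, harmonic_num + 2, dtype=f0.dtype)
        fn = f0 * ratios
        # per-sample phase increment in revolutions, then accumulate and convert to radians
        rad_values = (fn / sampling_rate) % 1
        phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
        return torch.sin(phase)

    f0 = torch.full((1, 240, 1), 220.0)   # a constant 220 Hz contour, shape (batch, length, 1)
    sines = toy_f0_to_sines(f0)
    print(sines.shape)                    # torch.Size([1, 240, 8])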
xinference/thirdparty/cosyvoice/llm/llm.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import queue
 import random
+import time
+import threading
 from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
@@ -170,6 +174,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -270,7 +275,6 @@ class Qwen2LM(TransformerLM):
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
-
         # 2. build speech token language model related modules
         self.sos_eos = 0
         self.task_id = 1
@@ -293,6 +297,10 @@ class Qwen2LM(TransformerLM):
         self.sampling = sampling
         self.mix_ratio = mix_ratio
 
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(3)]
+        self.vllm_output_queue = {}
+
     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
@@ -369,6 +377,53 @@ class Qwen2LM(TransformerLM):
         acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
 
+    def forward_dpo(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+        reject_speech_token = batch['reject_speech_token'].to(device)
+        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
+        speech_token_combined = speech_token + reject_speech_token
+        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
+        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
+        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        chosen_logits = logits[:text_token.shape[0]]
+        rejected_logits = logits[text_token.shape[0]:]
+        chosen_lm_target = lm_target[:text_token.shape[0]]
+        rejected_lm_target = lm_target[text_token.shape[0]:]
+        loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
+        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
+
+        # 5. calculate dpo logits
+        chosen_lm_mask = chosen_lm_target == IGNORE_ID
+        rejected_lm_mask = rejected_lm_target == IGNORE_ID
+        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
+        rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
+        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
+
     @torch.inference_mode()
     def inference(
             self,
@@ -382,6 +437,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +458,57 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
 
         # 5. step by step decode
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        out_tokens
-
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token
+
+    @torch.inference_mode()
+    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
+        if hasattr(self, 'vllm'):
+            from vllm import SamplingParams, RequestOutput
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=self.stop_token_ids,
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            with self.lock:
+                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
+                self.vllm_output_queue[uuid] = queue.Queue()
+            out_tokens = []
+            while True:
+                with self.lock:
+                    if self.vllm_output_queue[uuid].empty() is True:
+                        request_outputs: List[RequestOutput] = self.vllm.step()
+                        for request_output in request_outputs:
+                            top_ids = list(request_output.outputs[0].token_ids)[-1]
+                            self.vllm_output_queue[request_output.request_id].put(top_ids)
+                if self.vllm_output_queue[uuid].empty() is False:
+                    top_ids = self.vllm_output_queue[uuid].get()
+                    if top_ids in self.stop_token_ids:
+                        break
+                    # in stream mode, yield token one by one
+                    yield top_ids
+                    out_tokens.append(top_ids)
+                    if len(out_tokens) == max_len:
+                        break
+                time.sleep(0.001)
+            with self.lock:
+                self.vllm_output_queue.pop(uuid)
+        else:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                          cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
     @torch.inference_mode()
     def inference_bistream(