xinference 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +97 -8
- xinference/client/restful/restful_client.py +51 -11
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/worker.py +31 -37
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +1 -0
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +20 -3
- xinference/model/audio/model_spec_modelscope.json +18 -1
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +37 -110
- xinference/model/llm/core.py +15 -6
- xinference/model/llm/llama_cpp/core.py +25 -353
- xinference/model/llm/llm_family.json +613 -89
- xinference/model/llm/llm_family.py +9 -1
- xinference/model/llm/llm_family_modelscope.json +540 -90
- xinference/model/llm/mlx/core.py +6 -3
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +16 -3
- xinference/model/llm/transformers/chatglm.py +2 -2
- xinference/model/llm/transformers/cogagent.py +1 -1
- xinference/model/llm/transformers/cogvlm2.py +1 -1
- xinference/model/llm/transformers/core.py +9 -3
- xinference/model/llm/transformers/glm4v.py +1 -1
- xinference/model/llm/transformers/minicpmv26.py +1 -1
- xinference/model/llm/transformers/qwen-omni.py +6 -0
- xinference/model/llm/transformers/qwen_vl.py +1 -1
- xinference/model/llm/utils.py +68 -45
- xinference/model/llm/vllm/core.py +38 -18
- xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +133 -16
- xinference/model/video/model_spec.json +54 -0
- xinference/model/video/model_spec_modelscope.json +56 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +0 -71
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/src/locales/en.json +6 -4
- xinference/web/ui/src/locales/zh.json +6 -4
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/RECORD +87 -87
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
- xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
@@ -20,6 +21,8 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.mask import make_pad_mask
 
 
 class TransformerLM(torch.nn.Module):
@@ -144,10 +147,14 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
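The guard added here caps the EOS-resampling loop in `sampling_ids` at 100 attempts instead of spinning forever on degenerate inputs. A minimal standalone sketch of the same retry-with-cap pattern (the sampler and token ids below are illustrative, not xinference code):

import random

def sample_until_not_eos(sample_fn, eos_id, max_trials=100):
    # re-draw until the sampler returns something other than EOS,
    # but fail fast after max_trials instead of looping forever
    for _ in range(max_trials):
        token = sample_fn()
        if token != eos_id:
            return token
    raise RuntimeError('sampling reached max_trials and still got eos, check your input!')

print(sample_until_not_eos(lambda: random.choice([0, 1, 2, 6561]), eos_id=6561))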
@@ -178,7 +185,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -221,6 +228,17 @@ class Qwen2Encoder(torch.nn.Module):
         super().__init__()
         self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
 
+    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T)
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=masks,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+        return outs.hidden_states[-1], masks.unsqueeze(1)
+
     def forward_one_step(self, xs, masks, cache=None):
         input_masks = masks[:, -1, :]
         outs = self.model(
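The new `Qwen2Encoder.forward` turns per-sequence lengths into an attention mask with `make_pad_mask` and runs the embeddings through `Qwen2ForCausalLM`, returning the last hidden states. A small sketch of the length-to-mask step, assuming the usual wenet-style `make_pad_mask` semantics (True marks padding); the helper below is illustrative, not the actual implementation:

import torch

def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True where a position is padding, False where it is a real token
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, T)
    return positions >= lengths.unsqueeze(1)                               # (B, T)

lengths = torch.tensor([3, 1])
attention_mask = ~make_pad_mask_sketch(lengths, max_len=4)  # as in the diff: ~make_pad_mask(xs_lens, T)
print(attention_mask)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])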
@@ -236,7 +254,7 @@ class Qwen2Encoder(torch.nn.Module):
         return xs, new_cache
 
 
-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -246,8 +264,9 @@ class Qwen2LM(torch.nn.Module):
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -272,19 +291,83 @@ class Qwen2LM(torch.nn.Module):
 
         # 4. sampling method
         self.sampling = sampling
+        self.mix_ratio = mix_ratio
 
-    def
+    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+        lm_target, lm_input = [], []
+        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
+        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        for i in range(len(text_token)):
+            # bistream sequence
+            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
+                this_lm_target, this_lm_input = [], []
+                this_lm_target.append(IGNORE_ID)
+                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
+                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
+                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
+                    if len(this_text_token) == self.mix_ratio[0]:
+                        assert len(this_speech_token) == self.mix_ratio[1]
+                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
+                        this_lm_target += this_speech_token
+                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
+                    else:
+                        this_lm_target += [-1] * len(this_text_token)
+                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
+                        this_lm_target.append(self.speech_token_size)
+                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
+                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
+                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
+            # unistream sequence
+            else:
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+            lm_target.append(this_lm_target)
+            lm_input.append(this_lm_input)
+        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
+        return lm_target, lm_input, lm_input_len
+
+    def forward(
             self,
-
-
-
-
-
-
-
-
-
-
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}
 
     @torch.inference_mode()
     def inference(
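`prepare_lm_input_target` builds training sequences in two flavors: a unistream layout (all text, then all speech) and a bistream layout that interleaves blocks of `mix_ratio[0]` = 5 text tokens with `mix_ratio[1]` = 15 speech tokens, ending each full block with a fill token (`speech_token_size + 2`) and ignoring text positions in the loss. A toy illustration of one full bistream block in the target, with made-up token ids and illustrative stand-ins for the real constants:

IGNORE_ID = -1        # loss-ignored position (stand-in value)
FILL_TOKEN = 6563     # stand-in for speech_token_size + 2
mix_ratio = [5, 15]   # 5 text tokens are paired with 15 speech tokens

speech_block = list(range(1000, 1015))   # 15 toy speech token ids

# target for one full 5:15 block, mirroring the bistream branch of the diff:
# (mix_ratio[0] - 1) ignored text positions, the speech block, then the fill token
block_target = [IGNORE_ID] * (mix_ratio[0] - 1) + speech_block + [FILL_TOKEN]
print(len(block_target))   # 20 positions per full block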
@@ -305,9 +388,6 @@ class Qwen2LM(torch.nn.Module):
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)
 
-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -315,7 +395,7 @@ class Qwen2LM(torch.nn.Module):
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb,
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -338,3 +418,103 @@ class Qwen2LM(torch.nn.Module):
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                            lm_input = lm_input_text
+                        else:
+                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
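`inference_bistream` consumes text tokens from a generator and starts yielding speech tokens as soon as enough text has arrived, which is what enables streamed synthesis. A hedged driver sketch; `qwen2lm`, the token ids, and the zero-length prompts below are placeholders and may not match how xinference actually wires this method up:

import torch

def text_chunks():
    # placeholder: yields (1, n) LongTensors of text token ids as they become available
    for ids in ([101, 102, 103], [104, 105, 106, 107]):
        yield torch.tensor([ids], dtype=torch.long)

speech_tokens = []
for tok in qwen2lm.inference_bistream(            # qwen2lm: a loaded Qwen2LM instance (placeholder)
        text=text_chunks(),
        prompt_text=torch.zeros(1, 0, dtype=torch.long),
        prompt_text_len=torch.tensor([0]),
        prompt_speech_token=torch.zeros(1, 0, dtype=torch.long),
        prompt_speech_token_len=torch.tensor([0]),
        embedding=torch.zeros(1, 192)):
    speech_tokens.append(tok)                     # each yielded item is one speech token id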
@@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         Returns:
             torch.Tensor: Corresponding encoding
         """
-
-
-
-
+        # How to subscript a Union type:
+        # https://github.com/pytorch/pytorch/issues/69434
+        if isinstance(offset, int):
+            pos_emb = self.pe[
+                :,
+                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+            ]
+        elif isinstance(offset, torch.Tensor):
+            pos_emb = self.pe[
+                :,
+                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+            ]
         return pos_emb
@@ -56,11 +56,16 @@ class Upsample1D(nn.Module):
         # In this mode, first repeat interpolate, than conv with stride=1
         self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
 
-    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
-
+        if conv_cache.size(2) == 0:
+            outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+        else:
+            assert conv_cache.size(2) == self.stride * 2
+            outputs = torch.concat([conv_cache, outputs], dim=2)
+        conv_cache_new = outputs[:, :, -self.stride * 2:]
         outputs = self.conv(outputs)
-        return outputs, input_lengths * self.stride
+        return outputs, input_lengths * self.stride, conv_cache_new
 
 
 class PreLookaheadLayer(nn.Module):
@@ -78,22 +83,32 @@ class PreLookaheadLayer(nn.Module):
             kernel_size=3, stride=1, padding=0,
         )
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         inputs: (batch_size, seq_len, channels)
         """
         outputs = inputs.transpose(1, 2).contiguous()
+        context = context.transpose(1, 2).contiguous()
         # look ahead
-
+        if context.size(2) == 0:
+            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+        else:
+            assert context.size(2) == self.pre_lookahead_len
+            outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
         outputs = F.leaky_relu(self.conv1(outputs))
         # outputs
-
+        if conv2_cache.size(2) == 0:
+            outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
+        else:
+            assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
+            outputs = torch.concat([conv2_cache, outputs], dim=2)
+        conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
         outputs = self.conv2(outputs)
         outputs = outputs.transpose(1, 2).contiguous()
 
         # residual connection
         outputs = outputs + inputs
-        return outputs
+        return outputs, conv2_cache_new
 
 
 class UpsampleConformerEncoder(torch.nn.Module):
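`Upsample1D.forward` and `PreLookaheadLayer.forward` now take and return small convolution caches so that chunked (streaming) inference reproduces full-sequence outputs at chunk boundaries. A minimal sketch of the carry-the-cache-forward pattern on a generic causal 1-D convolution (not the actual classes):

import torch
import torch.nn.functional as F
from torch import nn

class CachedCausalConv1d(nn.Module):
    # keep the last (kernel_size - 1) frames so the next chunk sees the same
    # left context it would have seen in a single full-sequence pass
    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size, padding=0)
        self.ctx = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)):
        if cache.size(2) == 0:
            x = F.pad(x, (self.ctx, 0), value=0.0)   # first chunk: zero left context
        else:
            x = torch.concat([cache, x], dim=2)      # later chunks: cached left context
        new_cache = x[:, :, -self.ctx:]
        return self.conv(x), new_cache

layer = CachedCausalConv1d(channels=4, kernel_size=3)
chunks = torch.randn(1, 4, 20).split(5, dim=2)
cache, outs = torch.zeros(0, 0, 0), []
for chunk in chunks:
    y, cache = layer(chunk, cache)
    outs.append(y)
full, _ = layer(torch.cat(chunks, dim=2))
assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-6)  # chunked == full-sequence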
@@ -240,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.
 
@@ -270,30 +286,20 @@ class UpsampleConformerEncoder(torch.nn.Module):
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
         # lookahead + conformer encoder
-        xs = self.pre_lookahead_layer(xs)
+        xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         # upsample + conformer encoder
         xs = xs.transpose(1, 2).contiguous()
-        xs, xs_lens = self.up_layer(xs, xs_lens)
+        xs, xs_lens, _ = self.up_layer(xs, xs_lens)
         xs = xs.transpose(1, 2).contiguous()
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         if self.normalize_before:
@@ -316,3 +322,100 @@ class UpsampleConformerEncoder(torch.nn.Module):
         for layer in self.up_encoders:
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
+
+    @torch.jit.export
+    def forward_chunk(
+        self,
+        xs: torch.Tensor,
+        xs_lens: torch.Tensor,
+        offset: int = 0,
+        context: torch.Tensor = torch.zeros(0, 0, 0),
+        pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
+        encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
+        upsample_offset: int = 0,
+        upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
+        upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs: padded input tensor (B, T, D)
+            xs_lens: input length (B)
+            decoding_chunk_size: decoding chunk size for dynamic chunk
+                0: default for training, use random dynamic chunk.
+                <0: for decoding, use full chunk.
+                >0: for decoding, use fixed chunk size as set.
+            num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+                >=0: use num_decoding_left_chunks
+                <0: use all left chunks
+        Returns:
+            encoder output tensor xs, and subsampled masks
+            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+            masks: torch.Tensor batch padding mask after subsample
+                (B, 1, T' ~= T/subsample_rate)
+        NOTE(xcsong):
+            We pass the `__call__` method of the modules instead of `forward` to the
+            checkpointing API because `__call__` attaches all the hooks of the module.
+            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        """
+        assert xs.size(0) == 1
+        # tmp_masks is just for interface compatibility
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+        offset += xs.size(1)
+        tmp_masks = torch.ones(1,
+                               context.size(1),
+                               device=context.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if context.size(1) != 0:
+            context, _, _ = self.embed(context, tmp_masks, offset)
+
+        # lookahead + conformer encoder
+        xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
+        # NOTE in cache mode we do not need to call add_optional_chunk_mask
+        chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
+        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+        encoders_kv_cache_list = []
+        for index, layer in enumerate(self.encoders):
+            xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
+            encoders_kv_cache_list.append(encoders_kv_cache_new)
+        encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
+
+        # upsample
+        xs = xs.transpose(1, 2).contiguous()
+        xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
+        xs = xs.transpose(1, 2).contiguous()
+
+        # tmp_masks is just for interface compatibility
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
+        upsample_offset += xs.size(1)
+
+        # conformer encoder
+        chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
+        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+        upsample_kv_cache_list = []
+        for index, layer in enumerate(self.up_encoders):
+            xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
+            upsample_kv_cache_list.append(upsample_kv_cache_new)
+        upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        # Here we assume the mask is not changed in encoder layers, so just
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
+        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
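`forward_chunk` is the streaming counterpart of `forward`: each call consumes one chunk plus the caches returned by the previous call and returns the updated cache tuple (offset, pre-lookahead conv cache, encoder KV cache, upsample offset, upsample conv cache, upsample KV cache). A hedged driver-loop sketch; `encoder` and `chunk_pairs` are placeholders for an instantiated UpsampleConformerEncoder and a chunked input stream with lookahead context:

import torch

caches = None
for chunk, context in chunk_pairs:   # placeholder: (current chunk, lookahead context or empty tensor)
    xs_lens = torch.tensor([chunk.size(1)], dtype=torch.int32)
    if caches is None:
        # first chunk: rely on the empty-cache defaults declared in the signature
        ys, masks, caches = encoder.forward_chunk(chunk, xs_lens, context=context)
    else:
        offset, pre_cache, enc_kv, up_offset, up_conv, up_kv = caches
        ys, masks, caches = encoder.forward_chunk(
            chunk, xs_lens,
            offset=offset,
            context=context,
            pre_lookahead_layer_conv2_cache=pre_cache,
            encoders_kv_cache=enc_kv,
            upsample_offset=up_offset,
            upsample_conv_cache=up_conv,
            upsample_kv_cache=up_kv,
        )
    # ys feeds the downstream flow decoder / vocoder for this chunk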
@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 
 
 COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,12 @@ COSYVOICE_ATTENTION_CLASSES = {
     "selfattn": MultiHeadedAttention,
     "rel_selfattn": RelPositionMultiHeadedAttention,
 }
+
+
+def get_model_type(configs):
+    # NOTE CosyVoice2Model inherits CosyVoiceModel
+    if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoiceModel
+    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoice2Model
+    raise TypeError('No valid model type found!')
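`get_model_type` dispatches on the already-instantiated `llm`/`flow`/`hift` modules in a loaded config, so callers no longer need to know up front whether a checkpoint is CosyVoice or CosyVoice 2. A hedged usage sketch; the hyperpyyaml step is how CosyVoice configs are normally materialized, but the path below is a placeholder:

from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.class_utils import get_model_type

# placeholder path; load_hyperpyyaml instantiates the llm / flow / hift modules
with open('pretrained_models/CosyVoice2-0.5B/cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)

model_cls = get_model_type(configs)   # CosyVoiceModel or CosyVoice2Model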
@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # attention mask bias
     # NOTE(Mddct): torch.finfo jit issues
     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
-    mask = (1.0 - mask) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
     return mask
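The changed line replaces `torch.finfo(dtype).min` with a constant -1.0e+10 bias, sidestepping the TorchScript issue noted in the comment while still driving masked positions to effectively zero probability after softmax. A quick standalone illustration of the effect (not the xinference function itself):

import torch

mask = torch.tensor([1.0, 1.0, 0.0])      # 1 = keep, 0 = masked
bias = (1.0 - mask) * -1.0e+10            # same expression as the fixed line
logits = torch.tensor([2.0, 1.0, 5.0]) + bias
print(torch.softmax(logits, dim=-1))      # masked position ends up with ~0 probability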
@@ -1,5 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -39,9 +39,47 @@ def read_json_lists(list_file):
 
 
 def load_wav(wav, target_sr):
-    speech, sample_rate = torchaudio.load(wav)
+    speech, sample_rate = torchaudio.load(wav, backend='soundfile')
     speech = speech.mean(dim=0, keepdim=True)
     if sample_rate != target_sr:
         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
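`convert_onnx_to_trt` expects `trt_kwargs` to carry parallel lists of input names and min/opt/max shapes for the TensorRT optimization profile. A hedged call sketch; the file names and the single input shape below are placeholders, not the shapes the CosyVoice flow decoder actually exports:

from cosyvoice.utils.file_utils import convert_onnx_to_trt

trt_kwargs = {
    'input_names': ['x'],          # placeholder input name
    'min_shape': [(2, 80, 4)],     # smallest shape the engine must accept
    'opt_shape': [(2, 80, 200)],   # shape TensorRT optimizes for
    'max_shape': [(2, 80, 3000)],  # largest shape the engine must accept
}
convert_onnx_to_trt(
    trt_model='estimator.fp16.plan',   # placeholder output path
    trt_kwargs=trt_kwargs,
    onnx_model='estimator.fp32.onnx',  # placeholder input path
    fp16=True,
)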
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+import regex
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
 
@@ -127,3 +128,9 @@ def replace_blank(text: str):
         else:
             out_str.append(c)
     return "".join(out_str)
+
+
+def is_only_punctuation(text):
+    # Regular expression: Match strings that consist only of punctuation marks or are empty.
+    punctuation_pattern = r'^[\p{P}\p{S}]*$'
+    return bool(regex.fullmatch(punctuation_pattern, text))
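`is_only_punctuation` relies on the third-party `regex` module because the Unicode property classes `\p{P}` (punctuation) and `\p{S}` (symbols) are not supported by the standard `re` module. A few illustrative inputs, restating the helper so the snippet runs standalone:

import regex

def is_only_punctuation(text):
    return bool(regex.fullmatch(r'^[\p{P}\p{S}]*$', text))

print(is_only_punctuation('!?…。'))   # True: punctuation/symbols only
print(is_only_punctuation(''))        # True: the empty string matches the * quantifier
print(is_only_punctuation('你好!'))   # False: contains non-punctuation characters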
@@ -195,6 +195,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
         chunk_masks = masks & chunk_masks  # (B, L, L)
     else:
         chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
     return chunk_masks
 
 
@@ -286,11 +286,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # optimizer.step().
             if torch.isfinite(grad_norm):
                 scaler.step(optimizer)
+            else:
+                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
             scaler.update()
         else:
             grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
             if torch.isfinite(grad_norm):
                 optimizer.step()
+            else:
+                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
         optimizer.zero_grad()
         scheduler.step()
     info_dict["lr"] = optimizer.param_groups[0]['lr']
@@ -336,7 +340,7 @@ def log_per_save(writer, info_dict):
     rank = int(os.environ.get('RANK', 0))
     logging.info(
         'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{}
+            epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
 
     if writer is not None:
         for k in ['epoch', 'lr']: