xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
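The headline addition in this file list is the F5-TTS audio backend (f5tts.py, f5tts_mlx.py plus the bundled xinference/thirdparty/f5_tts package). A hedged sketch of trying it through the RESTful client follows; the model name "F5-TTS" and the speech() call are assumptions based on this file list and the existing audio-model client API, not verified against the 1.1.1 docs.

    # Hedged sketch: launching the newly added F5-TTS audio model via the client.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")                        # a running xinference endpoint
    model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")  # model name is an assumption
    model = client.get_model(model_uid)

    # audio models expose an OpenAI-style speech() call returning audio bytes
    audio = model.speech("Hello from xinference 1.1.1.")
    with open("out.mp3", "wb") as f:
        f.write(audio)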
xinference/thirdparty/cosyvoice/flow/length_regulator.py

@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
         olens = ylens
         return out * mask, olens

-    def inference(self, x1, x2, mel_len1, mel_len2):
+    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 -
-
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
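The hunk above replaces a hard-coded prompt-head length with one derived from the new input_frame_rate argument. A small standalone sketch of that arithmetic; the 22050 Hz sample rate and hop length 256 are taken from the diff, while the 25 Hz value is only an illustrative second frame rate.

    def prompt_head_frames(input_frame_rate: int, n_tokens: int = 20,
                           sample_rate: int = 22050, hop: int = 256) -> int:
        # n_tokens speech tokens cover n_tokens / input_frame_rate seconds of audio;
        # dividing that duration by the mel hop length converts it into mel frames
        return int(n_tokens / input_frame_rate * sample_rate / hop)

    print(prompt_head_frames(50))   # 34 mel frames for 50 Hz speech tokens
    print(prompt_head_frames(25))   # 68 mel frames for 25 Hz speech tokens (illustrative)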
xinference/thirdparty/cosyvoice/hifigan/discriminator.py

@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+
+class MultipleDiscriminator(nn.Module):
+    def __init__(
+        self, mpd: nn.Module, mrd: nn.Module
+    ):
+        super().__init__()
+        self.mpd = mpd
+        self.mrd = mrd
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(
+        self,
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+        num_embeddings: Optional[int] = None,
+    ):
+        """
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+        Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+        Args:
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                Defaults to None.
+        """
+
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+        )
+
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        num_embeddings: Optional[int] = None,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        if num_embeddings is not None:
+            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+            torch.nn.init.zeros_(self.emb.weight)
+
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
+        fmap = []
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
+        if cond_embedding_id is not None:
+            emb = self.emb(cond_embedding_id)
+            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+        else:
+            h = 0
+        x = self.conv_post(x)
+        fmap.append(x)
+        x += h
+
+        return x, fmap
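A hedged usage sketch for the discriminator file added above; the batch size and waveform length are illustrative, and the import path assumes the bundled cosyvoice package is importable as a top level module, as the other thirdparty files in this diff do.

    import torch
    from cosyvoice.hifigan.discriminator import MultiResolutionDiscriminator

    mrd = MultiResolutionDiscriminator()      # default fft_sizes (2048, 1024, 512)
    real = torch.randn(2, 22050)              # a batch of two 1-second 22.05 kHz waveforms
    fake = torch.randn(2, 22050)

    # one logit tensor and one feature-map list per FFT resolution
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(real, fake)
    print(len(y_d_rs), y_d_rs[0].shape)       # 3 resolutions; each output is (B, 1, T', F)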
xinference/thirdparty/cosyvoice/hifigan/generator.py

@@ -14,7 +14,7 @@

 """HIFI-GAN"""

-
+from typing import Dict, Optional, List
 import numpy as np
 from scipy.signal import get_window
 import torch
@@ -38,13 +38,15 @@ This code is modified from https://github.com/jik876/hifi-gan
 https://github.com/NVIDIA/BigVGAN

 """
+
+
 class ResBlock(torch.nn.Module):
     """Residual block module in HiFiGAN/BigVGAN."""
     def __init__(
         self,
         channels: int = 512,
         kernel_size: int = 3,
-        dilations:
+        dilations: List[int] = [1, 3, 5],
     ):
         super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList()
@@ -100,6 +102,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(self.convs1[idx])
             remove_weight_norm(self.convs2[idx])

+
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -231,13 +234,13 @@ class HiFTGenerator(nn.Module):
             nsf_alpha: float = 0.1,
             nsf_sigma: float = 0.003,
             nsf_voiced_threshold: float = 10,
-            upsample_rates:
-            upsample_kernel_sizes:
-            istft_params:
-            resblock_kernel_sizes:
-            resblock_dilation_sizes:
-            source_resblock_kernel_sizes:
-            source_resblock_dilation_sizes:
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
             lrelu_slope: float = 0.1,
             audio_limit: float = 0.99,
             f0_predictor: torch.nn.Module = None,
@@ -286,8 +289,7 @@ class HiFTGenerator(nn.Module):
         self.source_resblocks = nn.ModuleList()
         downsample_rates = [1] + upsample_rates[::-1][:-1]
         downsample_cum_rates = np.cumprod(downsample_rates)
-        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
-                                          source_resblock_dilation_sizes)):
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
             if u == 1:
                 self.source_downs.append(
                     Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
@@ -304,7 +306,7 @@ class HiFTGenerator(nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = base_channels // (2**(i + 1))
-            for
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(ResBlock(ch, k, d))

         self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
@@ -314,11 +316,19 @@ class HiFTGenerator(nn.Module):
         self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
         self.f0_predictor = f0_predictor

-    def
-
-
-
-
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+        self.m_source.remove_weight_norm()
+        for l in self.source_downs:
+            remove_weight_norm(l)
+        for l in self.source_resblocks:
+            l.remove_weight_norm()

     def _stft(self, x):
         spec = torch.stft(
@@ -332,17 +342,11 @@ class HiFTGenerator(nn.Module):
         magnitude = torch.clip(magnitude, max=1e2)
         real = magnitude * torch.cos(phase)
         img = magnitude * torch.sin(phase)
-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform

-    def
-        f0 = self.f0_predictor(x)
-        s = self._f02source(f0)
-
-        # use cache_source to avoid glitch
-        if cache_source.shape[2] == 0:
-            s[:, :, :cache_source.shape[2]] = cache_source
-
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

@@ -374,22 +378,34 @@ class HiFTGenerator(nn.Module):

         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x

-    def
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # mel+source->speech
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, f0

     @torch.inference_mode()
-    def inference(self,
-
+    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, s
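The generator hunks above split HiFTGenerator into a training forward() and a streaming inference() that accepts a cache_source tensor and also returns the source it used. A minimal standalone sketch of that cache guard follows; the tail-slice caching policy in the last line is illustrative, not taken from this diff.

    import torch

    def apply_source_cache(s: torch.Tensor, cache_source: torch.Tensor) -> torch.Tensor:
        # mirrors the guard in the new HiFTGenerator.inference: if a source cached
        # from the previous chunk exists, overwrite the overlapping head frames so
        # the excitation stays continuous across streamed chunks
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        return s

    s = torch.randn(1, 1, 100)        # source excitation for the current chunk
    cache = torch.zeros(1, 1, 0)      # first chunk: nothing cached yet
    s = apply_source_cache(s, cache)
    cache = s[:, :, -20:]             # illustrative policy: keep a tail as the next chunk's cache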
xinference/thirdparty/cosyvoice/hifigan/hifigan.py

@@ -0,0 +1,67 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
+from cosyvoice.utils.losses import tpr_loss, mel_loss
+
+
+class HiFiGan(nn.Module):
+    def __init__(self, generator, discriminator, mel_spec_transform,
+                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
+                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
+        super(HiFiGan, self).__init__()
+        self.generator = generator
+        self.discriminator = discriminator
+        self.mel_spec_transform = mel_spec_transform
+        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
+        self.feat_match_loss_weight = feat_match_loss_weight
+        self.tpr_loss_weight = tpr_loss_weight
+        self.tpr_loss_tau = tpr_loss_tau
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        if batch['turn'] == 'generator':
+            return self.forward_generator(batch, device)
+        else:
+            return self.forward_discriminator(batch, device)
+
+    def forward_generator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
+        loss_gen, _ = generator_loss(y_d_gs)
+        loss_fm = feature_loss(fmap_rs, fmap_gs)
+        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
+            self.multi_mel_spectral_recon_loss_weight * loss_mel + \
+            self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
+
+    def forward_discriminator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        # 1. calculate generator outputs
+        with torch.no_grad():
+            generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate discriminator losses, tpr losses [Optional]
+        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss = loss_disc + self.tpr_loss_weight * loss_tpr
+        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
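The new HiFiGan wrapper above dispatches on batch['turn'] so one module serves both GAN phases. A minimal runnable sketch of that contract, with a stub standing in for the real generator/discriminator pair and mel transform:

    import torch

    class StubHiFiGan(torch.nn.Module):
        # stands in for HiFiGan(generator, discriminator, mel_spec_transform)
        def forward(self, batch, device):
            if batch['turn'] == 'generator':
                return {'loss': torch.tensor(1.0, requires_grad=True)}
            return {'loss': torch.tensor(0.5, requires_grad=True)}

    gan = StubHiFiGan()
    for step in range(4):
        # the trainer alternates turns; each turn backpropagates only its own loss dict
        batch = {'turn': 'generator' if step % 2 == 0 else 'discriminator'}
        losses = gan(batch, device='cpu')
        losses['loss'].backward()
        print(step, batch['turn'], float(losses['loss']))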
xinference/thirdparty/cosyvoice/llm/llm.py

@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers import Qwen2ForCausalLM
 from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
@@ -80,7 +81,8 @@ class TransformerLM(torch.nn.Module):
     def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+                    for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         return lm_input, lm_input_len
@@ -104,7 +106,8 @@ class TransformerLM(torch.nn.Module):
         embedding = batch['embedding'].to(device)

         # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+                                  [self.speech_token_size]) for i in range(text_token.size(0))]
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

         # 1. encode text_token
@@ -124,7 +127,8 @@ class TransformerLM(torch.nn.Module):
         speech_token = self.speech_embedding(speech_token)

         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+                                                         task_id_emb, speech_token, speech_token_len)

         # 6. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -194,14 +198,143 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=
-
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
+                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
+                                                                                                 device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            # force continue decode first token
+            if i == 0:
+                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
             # in stream mode, yield token one by one
-            yield
+            yield top_ids
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class Qwen2Encoder(torch.nn.Module):
+    def __init__(self, pretrain_path):
+        super().__init__()
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
+
+    def forward_one_step(self, xs, masks, cache=None):
+        input_masks = masks[:, -1, :]
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=input_masks,
+            output_hidden_states=True,
+            return_dict=True,
+            use_cache=True,
+            past_key_values=cache,
+        )
+        xs = outs.hidden_states[-1]
+        new_cache = outs.past_key_values
+        return xs, new_cache
+
+
+class Qwen2LM(torch.nn.Module):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 3,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+
+    def sampling_ids(
+            self,
+            weighted_scores: torch.Tensor,
+            decoded_tokens: List,
+            sampling: int,
+            ignore_eos: bool = True,
+    ):
+        while True:
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+        return top_ids
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 2. encode embedding
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        cache = None
+        for i in range(max_len):
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            if top_ids == self.speech_token_size:
+                break
+            if top_ids > self.speech_token_size:
+                continue
+            # in stream mode, yield token one by one
+            yield top_ids
+            out_tokens.append(top_ids)
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)