xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +91 -6
- xinference/client/restful/restful_client.py +39 -0
- xinference/core/model.py +41 -13
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +26 -4
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +71 -35
- xinference/model/audio/model_spec.json +88 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +322 -6
- xinference/model/embedding/model_spec.json +8 -1
- xinference/model/embedding/model_spec_modelscope.json +9 -1
- xinference/model/llm/__init__.py +4 -2
- xinference/model/llm/llm_family.json +479 -53
- xinference/model/llm/llm_family_modelscope.json +423 -17
- xinference/model/llm/mlx/core.py +230 -50
- xinference/model/llm/sglang/core.py +2 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +23 -1
- xinference/model/llm/vllm/core.py +89 -2
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/types.py +2 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
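The headline additions in this release are audio-side: a vendored copy of F5-TTS under xinference/thirdparty/f5_tts, a new f5tts.py audio model wrapper, an MLX-backed Whisper (whisper_mlx.py), and the GLM-Edge-V transformers backend. As a minimal, hypothetical usage sketch (not taken from this diff), the new audio model could be driven through the existing RESTful client; the endpoint and the model name "F5-TTS" are assumptions based on the model_spec.json entry added here:

# Hypothetical sketch: launch the new F5-TTS audio model via the RESTful client.
# Assumes a local xinference supervisor on the default endpoint and that the
# new audio spec registers the model under the name "F5-TTS".
from xinference.client import Client

client = Client("http://localhost:9997")

# model_type="audio" routes to the loaders under xinference/model/audio.
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# Text-to-speech: returns raw audio bytes (mp3 by default) to write to disk.
audio_bytes = model.speech("Hello from xinference 1.1.0!")
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)

The three hunks that follow are the new F5-TTS backbone modules added under xinference/thirdparty/f5_tts/model/backbones: dit.py (+163), mmdit.py (+146), and unett.py (+219).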
xinference/thirdparty/f5_tts/model/backbones/dit.py
@@ -0,0 +1,163 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvNeXtV2Block,
+    ConvPositionEmbedding,
+    DiTBlock,
+    AdaLayerNormZero_Final,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+        if conv_layers > 0:
+            self.extra_modeling = True
+            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+            self.text_blocks = nn.Sequential(
+                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+            )
+        else:
+            self.extra_modeling = False
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
+        text = F.pad(text, (0, seq_len - text_len), value=0)
+
+        if drop_text:  # cfg for text
+            text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b n -> b n d
+
+        # possible extra modeling
+        if self.extra_modeling:
+            # sinus pos emb
+            batch_start = torch.zeros((batch,), dtype=torch.long)
+            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+            text_pos_embed = self.freqs_cis[pos_idx]
+            text = text + text_pos_embed
+
+            # convnextv2 blocks
+            text = self.text_blocks(text)
+
+        return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim, text_dim, out_dim):
+        super().__init__()
+        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:  # cfg for cond audio
+            cond = torch.zeros_like(cond)
+
+        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Transformer backbone using DiT blocks
+
+
+class DiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=100,
+        text_num_embeds=256,
+        text_dim=None,
+        conv_layers=0,
+        long_skip_connection=False,
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        if text_dim is None:
+            text_dim = mel_dim
+        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
+        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
+        )
+        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # nosied input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if self.long_skip_connection is not None:
+            residual = x
+
+        for block in self.transformer_blocks:
+            x = block(x, t, mask=mask, rope=rope)
+
+        if self.long_skip_connection is not None:
+            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+        x = self.norm_out(x, t)
+        output = self.proj_out(x)
+
+        return output
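For orientation, here is a short, hypothetical shape check for the DiT backbone above (not part of the diff); it assumes the vendored f5_tts package and its x_transformers dependency are importable:

# Illustrative shape check for the DiT backbone; small config for demonstration.
import torch
from f5_tts.model.backbones.dit import DiT

model = DiT(dim=512, depth=4, heads=8, dim_head=64, mel_dim=100, text_num_embeds=256, conv_layers=4)

batch, seq_len, text_len = 2, 300, 120
x = torch.randn(batch, seq_len, 100)             # noised mel frames
cond = torch.randn(batch, seq_len, 100)          # masked conditioning mel
text = torch.randint(0, 256, (batch, text_len))  # character token ids
time = torch.rand(batch)                         # flow-matching time step in [0, 1]

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
assert out.shape == (batch, seq_len, 100)        # one mel-sized prediction per frame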
xinference/thirdparty/f5_tts/model/backbones/mmdit.py
@@ -0,0 +1,146 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvPositionEmbedding,
+    MMDiTBlock,
+    AdaLayerNormZero_Final,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, out_dim, text_num_embeds):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token
+
+        self.precompute_max_pos = 1024
+        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
+
+    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
+        text = text + 1
+        if drop_text:
+            text = torch.zeros_like(text)
+        text = self.text_embed(text)
+
+        # sinus pos emb
+        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
+        batch_text_len = text.shape[1]
+        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
+        text_pos_embed = self.freqs_cis[pos_idx]
+
+        text = text + text_pos_embed
+
+        return text
+
+
+# noised input & masked cond audio embedding
+
+
+class AudioEmbedding(nn.Module):
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+        self.linear = nn.Linear(2 * in_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:
+            cond = torch.zeros_like(cond)
+        x = torch.cat((x, cond), dim=-1)
+        x = self.linear(x)
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Transformer backbone using MM-DiT blocks
+
+
+class MMDiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        text_num_embeds=256,
+        mel_dim=100,
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        self.text_embed = TextEmbedding(dim, text_num_embeds)
+        self.audio_embed = AudioEmbedding(mel_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                MMDiTBlock(
+                    dim=dim,
+                    heads=heads,
+                    dim_head=dim_head,
+                    dropout=dropout,
+                    ff_mult=ff_mult,
+                    context_pre_only=i == depth - 1,
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # nosied input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch = x.shape[0]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        c = self.text_embed(text, drop_text=drop_text)
+        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
+
+        seq_len = x.shape[1]
+        text_len = text.shape[1]
+        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
+        rope_text = self.rotary_embed.forward_from_seq_len(text_len)
+
+        for block in self.transformer_blocks:
+            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
+
+        x = self.norm_out(x, t)
+        output = self.proj_out(x)
+
+        return output
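Unlike DiT, MMDiT keeps text and audio in separate streams with separate rotary embeddings and fuses them inside each MMDiTBlock; text is not padded to the mel length. The drop_audio_cond / drop_text flags zero out the conditioning, which is how a classifier-free-guidance sampler would obtain its unconditional branch. A hypothetical sketch (the guidance combine below mirrors the usual CFG formula and is not taken from this diff):

# Illustrative CFG-style use of the MMDiT backbone; assumes f5_tts is importable.
import torch
from f5_tts.model.backbones.mmdit import MMDiT

model = MMDiT(dim=512, depth=4, heads=8, dim_head=64, text_num_embeds=256, mel_dim=100)

x = torch.randn(2, 300, 100)
cond = torch.randn(2, 300, 100)
text = torch.randint(0, 256, (2, 120))  # text keeps its own length (no padding to mel length)
time = torch.rand(2)

pred = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
null_pred = model(x, cond, text, time, drop_audio_cond=True, drop_text=True)
cfg_strength = 2.0
guided = pred + (pred - null_pred) * cfg_strength  # standard CFG combine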
xinference/thirdparty/f5_tts/model/backbones/unett.py
@@ -0,0 +1,219 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+from typing import Literal
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from x_transformers import RMSNorm
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvNeXtV2Block,
+    ConvPositionEmbedding,
+    Attention,
+    AttnProcessor,
+    FeedForward,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+        if conv_layers > 0:
+            self.extra_modeling = True
+            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+            self.text_blocks = nn.Sequential(
+                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+            )
+        else:
+            self.extra_modeling = False
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
+        text = F.pad(text, (0, seq_len - text_len), value=0)
+
+        if drop_text:  # cfg for text
+            text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b n -> b n d
+
+        # possible extra modeling
+        if self.extra_modeling:
+            # sinus pos emb
+            batch_start = torch.zeros((batch,), dtype=torch.long)
+            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+            text_pos_embed = self.freqs_cis[pos_idx]
+            text = text + text_pos_embed
+
+            # convnextv2 blocks
+            text = self.text_blocks(text)
+
+        return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim, text_dim, out_dim):
+        super().__init__()
+        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:  # cfg for cond audio
+            cond = torch.zeros_like(cond)
+
+        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Flat UNet Transformer backbone
+
+
+class UNetT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=100,
+        text_num_embeds=256,
+        text_dim=None,
+        conv_layers=0,
+        skip_connect_type: Literal["add", "concat", "none"] = "concat",
+    ):
+        super().__init__()
+        assert depth % 2 == 0, "UNet-Transformer's depth should be even."
+
+        self.time_embed = TimestepEmbedding(dim)
+        if text_dim is None:
+            text_dim = mel_dim
+        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
+        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        # transformer layers & skip connections
+
+        self.dim = dim
+        self.skip_connect_type = skip_connect_type
+        needs_skip_proj = skip_connect_type == "concat"
+
+        self.depth = depth
+        self.layers = nn.ModuleList([])
+
+        for idx in range(depth):
+            is_later_half = idx >= (depth // 2)
+
+            attn_norm = RMSNorm(dim)
+            attn = Attention(
+                processor=AttnProcessor(),
+                dim=dim,
+                heads=heads,
+                dim_head=dim_head,
+                dropout=dropout,
+            )
+
+            ff_norm = RMSNorm(dim)
+            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+            skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
+
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        skip_proj,
+                        attn_norm,
+                        attn,
+                        ff_norm,
+                        ff,
+                    ]
+                )
+            )
+
+        self.norm_out = RMSNorm(dim)
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # nosied input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        # postfix time t to input x, [b n d] -> [b n+1 d]
+        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
+        if mask is not None:
+            mask = F.pad(mask, (1, 0), value=1)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
+
+        # flat unet transformer
+        skip_connect_type = self.skip_connect_type
+        skips = []
+        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
+            layer = idx + 1
+
+            # skip connection logic
+            is_first_half = layer <= (self.depth // 2)
+            is_later_half = not is_first_half
+
+            if is_first_half:
+                skips.append(x)
+
+            if is_later_half:
+                skip = skips.pop()
+                if skip_connect_type == "concat":
+                    x = torch.cat((x, skip), dim=-1)
+                    x = maybe_skip_proj(x)
+                elif skip_connect_type == "add":
+                    x = x + skip
+
+            # attention and feedforward blocks
+            x = attn(attn_norm(x), rope=rope, mask=mask) + x
+            x = ff(ff_norm(x)) + x
+
+        assert len(skips) == 0
+
+        x = self.norm_out(x)[:, 1:, :]  # unpack t from x
+
+        return self.proj_out(x)
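UNetT arranges its blocks as a flat UNet: the first depth//2 layers push their inputs onto a stack, and the second half pops and merges them, either by summing ("add") or by concatenating and projecting 2*dim back to dim ("concat"); the time embedding travels as an extra token prepended to the sequence and is stripped before the output projection. A hypothetical smoke test (not part of the diff), assuming the vendored f5_tts package is importable:

# Illustrative use of the UNetT backbone; depth must be even.
import torch
from f5_tts.model.backbones.unett import UNetT

model = UNetT(dim=512, depth=4, heads=8, dim_head=64, mel_dim=100, skip_connect_type="concat")

batch, seq_len = 2, 300
x = torch.randn(batch, seq_len, 100)
cond = torch.randn(batch, seq_len, 100)
text = torch.randint(0, 256, (batch, 120))
time = torch.rand(batch)

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
assert out.shape == (batch, seq_len, 100)  # the prepended time token is removed before proj_out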
|