xinference 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -5
- xinference/core/model.py +6 -1
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +2 -1
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +172 -53
- xinference/model/llm/llm_family_modelscope.json +118 -20
- xinference/model/llm/mlx/core.py +230 -49
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +4 -1
- xinference/model/llm/vllm/core.py +5 -0
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/model/modules.py
@@ -0,0 +1,658 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec


mel_basis_cache = {}
hann_window_cache = {}


def get_bigvgan_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
    fmin=0,
    fmax=None,
    center=False,
):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
    device = waveform.device
    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why they need .float()?
        hann_window_cache[key] = torch.hann_window(win_length).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    padding = (n_fft - hop_length) // 2
    waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)

    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec


def get_vocos_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)
    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    mel = mel.clamp(min=1e-5).log()
    return mel


class MelSpec(nn.Module):
    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type="vocos",
    ):
        super().__init__()
        assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.target_sample_rate = target_sample_rate

        if mel_spec_type == "vocos":
            self.extractor = get_vocos_mel_spectrogram
        elif mel_spec_type == "bigvgan":
            self.extractor = get_bigvgan_mel_spectrogram

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, wav):
        if self.dummy.device != wav.device:
            self.to(wav.device)

        mel = self.extractor(
            waveform=wav,
            n_fft=self.n_fft,
            n_mel_channels=self.n_mel_channels,
            target_sample_rate=self.target_sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

        return mel
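
# --- Illustrative usage sketch (editor's addition, not part of the packaged file). The
# --- front end above maps raw audio 'b nw' to a log-mel spectrogram 'b d n'; with the
# --- default 24 kHz sample rate and hop length 256, one second of audio yields 94 frames.
# >>> wav = torch.randn(2, 24_000)                # batch of 2, one second at 24 kHz
# >>> mel_extractor = MelSpec(mel_spec_type="vocos")
# >>> mel = mel_extractor(wav)
# >>> mel.shape
# torch.Size([2, 100, 94])
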
# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out


# rotary positional embedding related


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # avoid extra long error.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos
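
# --- Illustrative usage sketch (editor's addition, not part of the packaged file). The two
# --- helpers above build a sinusoidal feature table and batched position indices; a typical
# --- use is to index the table per position and add the result to a 'b n d' embedding.
# >>> dim, max_pos = 512, 4096
# >>> freqs_cis = precompute_freqs_cis(dim, max_pos)          # (max_pos, dim)
# >>> start = torch.zeros(2, dtype=torch.long)                # per-sample start offsets, 'b'
# >>> pos_idx = get_pos_embed_indices(start, 128, max_pos)    # 'b n' indices, clamped to max_pos - 1
# >>> pos_embed = freqs_cis[pos_idx]                          # 'b n d', added to an embedding of the same shape
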
# Global Response Normalization layer (Instance Normalization ?)


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x


# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation


class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation


class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


# FeedForward


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor


class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
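
# --- Illustrative usage sketch (editor's addition, not part of the packaged file). The
# --- Attention/AttnProcessor pair above performs plain self-attention over 'b n d' features;
# --- a boolean 'b n' mask (True = keep) excludes padded positions, e.g. when samples in a
# --- batch have different target durations, and zeroes them in the output.
# >>> attn = Attention(processor=AttnProcessor(), dim=512, heads=8, dim_head=64)
# >>> x = torch.randn(2, 128, 512)                 # 'b n d'
# >>> mask = torch.ones(2, 128, dtype=torch.bool)  # True where positions are valid
# >>> mask[1, 100:] = False                        # second sample is shorter
# >>> out = attn(x, mask=mask)                     # 'b n d', zeros at the masked positions
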
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c


# DiT Block


class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x


# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x


# time step conditioning embedding


class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
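
Taken together, the blocks added in this file form the DiT-style backbone behind the new F5-TTS audio model bundled in this release. As a rough, editor-supplied sketch (not taken from the wheel), and with the classes defined above in scope, a single block can be exercised as follows; the sizes are illustrative assumptions rather than values from the package.

import torch

# hypothetical sizes, for illustration only
dim, heads, dim_head, seq_len, batch = 512, 8, 64, 128, 2

time_embed = TimestepEmbedding(dim)                  # scalar timestep -> 'b d' conditioning vector
block = DiTBlock(dim=dim, heads=heads, dim_head=dim_head)

x = torch.randn(batch, seq_len, dim)                 # noised input sequence, 'b n d'
t = time_embed(torch.rand(batch))                    # 'b' -> 'b d'

out = block(x, t, mask=None, rope=None)              # AdaLN-Zero modulated attention + FFN, same shape as x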