xinference 0.14.3__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/core/worker.py +18 -9
- xinference/model/audio/chattts.py +4 -3
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/embedding/core.py +2 -0
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/stable_diffusion/core.py +21 -6
- xinference/model/llm/llm_family.py +5 -6
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/utils.py +3 -0
- xinference/model/llm/vllm/core.py +0 -33
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/tools/api.py +1 -1
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/METADATA +20 -12
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/RECORD +70 -28
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/matcha/models/components/decoder.py (new file)
@@ -0,0 +1,443 @@

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat

from matcha.models.components.transformer import BasicTransformerBlock


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block1D(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class ConformerWrapper(ConformerBlock):
    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.bool())


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
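The SinusoidalPosEmb module above is the standard transformer-style timestep encoding: each scalar timestep is multiplied by a geometric ladder of frequencies, and the sine and cosine of the results are concatenated into a dim-wide vector that conditions the U-Net. A minimal standalone sketch of the same computation, assuming only torch (the function name and shapes below are illustrative, not taken from the package):

import math
import torch

def sinusoidal_pos_emb(x: torch.Tensor, dim: int, scale: float = 1000.0) -> torch.Tensor:
    """Map a batch of scalar timesteps (batch,) to (batch, dim) embeddings."""
    half_dim = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000, as in the module above.
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000) / (half_dim - 1)))
    angles = scale * x.unsqueeze(1) * freq.unsqueeze(0)      # (batch, half_dim)
    return torch.cat((angles.sin(), angles.cos()), dim=-1)   # (batch, dim)

t = torch.tensor([0.0, 0.5, 1.0])   # three timesteps in [0, 1]
emb = sinusoidal_pos_emb(t, dim=256)
print(emb.shape)                    # torch.Size([3, 256])

Because the frequencies span several orders of magnitude, nearby timesteps produce similar low-frequency components while remaining distinguishable in the high-frequency ones, which is why this encoding works well as a conditioning signal.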
xinference/thirdparty/matcha/models/components/flow_matching.py (new file)
@@ -0,0 +1,132 @@

from abc import ABC

import torch
import torch.nn.functional as F

from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


class CFM(BASECFM):
    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)