xinference-1.5.1-py3-none-any.whl → xinference-1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +97 -8
- xinference/client/restful/restful_client.py +51 -11
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/worker.py +31 -37
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +1 -0
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +20 -3
- xinference/model/audio/model_spec_modelscope.json +18 -1
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +37 -110
- xinference/model/llm/core.py +15 -6
- xinference/model/llm/llama_cpp/core.py +25 -353
- xinference/model/llm/llm_family.json +613 -89
- xinference/model/llm/llm_family.py +9 -1
- xinference/model/llm/llm_family_modelscope.json +540 -90
- xinference/model/llm/mlx/core.py +6 -3
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +16 -3
- xinference/model/llm/transformers/chatglm.py +2 -2
- xinference/model/llm/transformers/cogagent.py +1 -1
- xinference/model/llm/transformers/cogvlm2.py +1 -1
- xinference/model/llm/transformers/core.py +9 -3
- xinference/model/llm/transformers/glm4v.py +1 -1
- xinference/model/llm/transformers/minicpmv26.py +1 -1
- xinference/model/llm/transformers/qwen-omni.py +6 -0
- xinference/model/llm/transformers/qwen_vl.py +1 -1
- xinference/model/llm/utils.py +68 -45
- xinference/model/llm/vllm/core.py +38 -18
- xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +133 -16
- xinference/model/video/model_spec.json +54 -0
- xinference/model/video/model_spec_modelscope.json +56 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +0 -71
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/src/locales/en.json +6 -4
- xinference/web/ui/src/locales/zh.json +6 -4
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/RECORD +87 -87
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
- xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/flow/decoder.py

```diff
@@ -11,14 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Tuple, Optional, Dict, Any
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
 from cosyvoice.utils.common import mask_to_bias
 from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
-from matcha.models.components.transformer import BasicTransformerBlock
+from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
 
 
 class Transpose(torch.nn.Module):
@@ -27,34 +29,11 @@ class Transpose(torch.nn.Module):
         self.dim0 = dim0
         self.dim1 = dim1
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         x = torch.transpose(x, self.dim0, self.dim1)
         return x
 
 
-class CausalBlock1D(Block1D):
-    def __init__(self, dim: int, dim_out: int):
-        super(CausalBlock1D, self).__init__(dim, dim_out)
-        self.block = torch.nn.Sequential(
-            CausalConv1d(dim, dim_out, 3),
-            Transpose(1, 2),
-            nn.LayerNorm(dim_out),
-            Transpose(1, 2),
-            nn.Mish(),
-        )
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor):
-        output = self.block(x * mask)
-        return output * mask
-
-
-class CausalResnetBlock1D(ResnetBlock1D):
-    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
-        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
-        self.block1 = CausalBlock1D(dim, dim_out)
-        self.block2 = CausalBlock1D(dim_out, dim_out)
-
-
 class CausalConv1d(torch.nn.Conv1d):
     def __init__(
         self,
@@ -76,12 +55,339 @@ class CausalConv1d(torch.nn.Conv1d):
                                            padding_mode=padding_mode,
                                            device=device, dtype=dtype)
         assert stride == 1
-        self.causal_padding =
+        self.causal_padding = kernel_size - 1
 
-    def forward(self, x: torch.Tensor):
-
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        cache = x[:, :, -self.causal_padding:]
         x = super(CausalConv1d, self).forward(x)
-        return x
+        return x, cache
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+        output, cache = self.block[0](x * mask, cache)
+        for i in range(1, len(self.block)):
+            output = self.block[i](output)
+        return output * mask, cache
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
+                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        h, block1_cache = self.block1(x, mask, block1_cache)
+        h += self.mlp(time_emb).unsqueeze(-1)
+        h, block2_cache = self.block2(h, mask, block2_cache)
+        output = h + self.res_conv(x * mask)
+        return output, block1_cache, block2_cache
+
+
+class CausalAttnProcessor2_0(AttnProcessor2_0):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        super(CausalAttnProcessor2_0, self).__init__()
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        *args,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
+                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key_cache = attn.to_k(encoder_hidden_states)
+        value_cache = attn.to_v(encoder_hidden_states)
+        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
+        if cache.size(0) != 0:
+            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
+            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
+        else:
+            key, value = key_cache, value_cache
+        cache = torch.stack([key_cache, value_cache], dim=3)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states, cache
+
+
+@maybe_allow_in_graph
+class CausalAttention(Attention):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor2_0"] = None,
+        out_dim: int = None,
+    ):
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
+                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
+                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
+                                              _from_deprecated_attn_block, processor, out_dim)
+        processor = CausalAttnProcessor2_0()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        **cross_attention_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            cache=cache,
+            **cross_attention_kwargs,
+        )
+
+
+@maybe_allow_in_graph
+class CausalBasicTransformerBlock(BasicTransformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
+    ):
+        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
+                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
+                                                          attention_bias, only_cross_attention, double_self_attention,
+                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
+        self.attn1 = CausalAttention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        attn_output, cache = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            cache=cache,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
+                    {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states, cache
 
 
 class ConditionalDecoder(nn.Module):
```
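The caching pattern introduced above is easiest to see on `CausalConv1d`: an empty cache triggers left zero-padding for the first chunk, and each later chunk is prepended with the last `kernel_size - 1` input frames returned by the previous call, so chunk-by-chunk processing reproduces the full-sequence output. A minimal sketch (not part of the diff), assuming the patched module is importable as `cosyvoice.flow.decoder`; shapes and chunk sizes are made up for illustration:

```python
import torch

from cosyvoice.flow.decoder import CausalConv1d  # patched class from this diff

conv = CausalConv1d(80, 80, 3).eval()
x = torch.randn(1, 80, 12)            # (batch, channels, time)

with torch.no_grad():
    y_full, _ = conv(x)               # empty default cache -> zero left-padding inside forward

    cache = torch.zeros(0, 0, 0)      # zero-sized cache marks the first chunk
    chunks = []
    for start in range(0, x.size(2), 4):
        # returned cache holds the last kernel_size - 1 input frames for the next chunk
        y_chunk, cache = conv(x[:, :, start:start + 4], cache)
        chunks.append(y_chunk)

print(torch.allclose(y_full, torch.cat(chunks, dim=2), atol=1e-6))  # True
```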
```diff
@@ -89,7 +395,6 @@ class ConditionalDecoder(nn.Module):
         self,
         in_channels,
         out_channels,
-        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -106,7 +411,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -123,8 +428,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet =
-                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -138,16 +442,14 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else
-                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet =
-                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
 
             transformer_blocks = nn.ModuleList(
                 [
@@ -169,11 +471,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet =
-                dim=input_channel,
-                dim_out=output_channel,
-                time_emb_dim=time_embed_dim,
-            ) if self.causal else ResnetBlock1D(
+            resnet = ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -193,10 +491,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block =
+        self.final_block = Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
@@ -214,7 +512,7 @@ class ConditionalDecoder(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -249,9 +547,8 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-
-            attn_mask =
-            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -268,9 +565,8 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-
-            attn_mask =
-            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -285,9 +581,8 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-
-            attn_mask =
-            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
```
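The three forward-path hunks above change how the self-attention mask is built in `ConditionalDecoder.forward`: a boolean mask now comes from `add_optional_chunk_mask` (chunk size 0 and -1 left chunks, i.e. a full attention window) and is converted to an additive bias with `mask_to_bias`, dropping the old `attn_mask == 1` conversion. A hedged sketch of what those two calls produce, reusing the exact arguments from the diff; shapes are illustrative and assume the vendored `cosyvoice` package is importable:

```python
import torch

from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask

x = torch.randn(1, 10, 256)   # (batch, time, channels), i.e. after the "b c t -> b t c" rearrange
mask = torch.ones(1, 1, 10)   # frame-validity mask, (batch, 1, time)

# static_chunk_size=0 and num_left_chunks=-1 keep the full attention window; the
# CausalConditionalDecoder added in the next hunk swaps in its static_chunk_size
# and num_decoding_left_chunks when streaming=True.
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
bias = mask_to_bias(attn_mask, x.dtype)  # ~0 where attention is allowed, a large negative value elsewhere
print(bias.shape)                        # torch.Size([1, 10, 10])
```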
```diff
@@ -299,3 +594,310 @@ class ConditionalDecoder(nn.Module):
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
         return output * mask
+
+
+class CausalConditionalDecoder(ConditionalDecoder):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+        static_chunk_size=50,
+        num_decoding_left_chunks=2,
+    ):
+        """
+        This decoder requires an input with the same shape of the target. So, if your text content
+        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+        """
+        torch.nn.Module.__init__(self)
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    CausalBasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for _ in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    CausalBasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i] * 2
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    CausalBasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+        self.final_block = CausalBlock1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+        self.initialize_weights()
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x, _, _ = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x, _ = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x, _ = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x, _, _ = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x, _ = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x, _, _ = resnet(x, mask_up, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x, _ = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x, _ = upsample(x * mask_up)
+        x, _ = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+        return output * mask
+
+    @torch.inference_mode()
+    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
+                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
+                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+
+        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
+        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
+            mask_down = masks[-1]
+            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
+                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, down_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=down_blocks_kv_cache[index, i],
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
+            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
+                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=mid_blocks_kv_cache[index, i]
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
+                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, up_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=up_blocks_kv_cache[index, i]
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
+        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
+        output = self.final_proj(x * mask_up)
+        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
+            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
```