xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/flow/decoder.py

@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from typing import Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import pack, rearrange, repeat
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
 from cosyvoice.utils.common import mask_to_bias
 from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
-from matcha.models.components.transformer import BasicTransformerBlock
+from matcha.models.components.transformer import BasicTransformerBlock
 
 
 class Transpose(torch.nn.Module):
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
         self.dim0 = dim0
         self.dim1 = dim1
 
-    def forward(self, x: torch.Tensor) ->
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = torch.transpose(x, self.dim0, self.dim1)
         return x
 
@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
         assert stride == 1
         self.causal_padding = kernel_size - 1
 
-    def forward(self, x: torch.Tensor
-
-            x = F.pad(x, (self.causal_padding, 0), value=0.0)
-        else:
-            assert cache.size(2) == self.causal_padding
-            x = torch.concat([cache, x], dim=2)
-        cache = x[:, :, -self.causal_padding:]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, (self.causal_padding, 0), value=0.0)
         x = super(CausalConv1d, self).forward(x)
-        return x
+        return x
 
 
 class CausalBlock1D(Block1D):
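The new `CausalConv1d.forward` above drops the streaming `cache` argument and always left-pads the input by `kernel_size - 1`. A minimal, self-contained sketch of why that padding keeps the convolution causal (illustration only, not the packaged class):

```python
import torch
import torch.nn.functional as F

# Left-padding a 1-D convolution by (kernel_size - 1) means every output frame
# is computed from the current and past input frames only, so no future
# context (and hence no cross-chunk cache) is required.
kernel_size = 3
conv = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, bias=False)

x = torch.randn(1, 1, 10)                              # (batch, channels, time)
x_padded = F.pad(x, (kernel_size - 1, 0), value=0.0)   # pad on the left only
y = conv(x_padded)

assert y.shape[-1] == x.shape[-1]                      # same number of frames in and out
```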
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
             nn.Mish(),
         )
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor
-        output
-
-            output = self.block[i](output)
-        return output * mask, cache
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.block(x * mask)
+        return output * mask
 
 
 class CausalResnetBlock1D(ResnetBlock1D):
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
         self.block1 = CausalBlock1D(dim, dim_out)
         self.block2 = CausalBlock1D(dim_out, dim_out)
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
-                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        h, block1_cache = self.block1(x, mask, block1_cache)
-        h += self.mlp(time_emb).unsqueeze(-1)
-        h, block2_cache = self.block2(h, mask, block2_cache)
-        output = h + self.res_conv(x * mask)
-        return output, block1_cache, block2_cache
-
-
-class CausalAttnProcessor2_0(AttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        super(CausalAttnProcessor2_0, self).__init__()
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        *args,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
-                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key_cache = attn.to_k(encoder_hidden_states)
-        value_cache = attn.to_v(encoder_hidden_states)
-        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
-        if cache.size(0) != 0:
-            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
-            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
-        else:
-            key, value = key_cache, value_cache
-        cache = torch.stack([key_cache, value_cache], dim=3)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states, cache
-
-
-@maybe_allow_in_graph
-class CausalAttention(Attention):
-    def __init__(
-        self,
-        query_dim: int,
-        cross_attention_dim: Optional[int] = None,
-        heads: int = 8,
-        dim_head: int = 64,
-        dropout: float = 0.0,
-        bias: bool = False,
-        upcast_attention: bool = False,
-        upcast_softmax: bool = False,
-        cross_attention_norm: Optional[str] = None,
-        cross_attention_norm_num_groups: int = 32,
-        qk_norm: Optional[str] = None,
-        added_kv_proj_dim: Optional[int] = None,
-        norm_num_groups: Optional[int] = None,
-        spatial_norm_dim: Optional[int] = None,
-        out_bias: bool = True,
-        scale_qk: bool = True,
-        only_cross_attention: bool = False,
-        eps: float = 1e-5,
-        rescale_output_factor: float = 1.0,
-        residual_connection: bool = False,
-        _from_deprecated_attn_block: bool = False,
-        processor: Optional["AttnProcessor2_0"] = None,
-        out_dim: int = None,
-    ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
-                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
-                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
-                                              _from_deprecated_attn_block, processor, out_dim)
-        processor = CausalAttnProcessor2_0()
-        self.set_processor(processor)
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        **cross_attention_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        r"""
-        The forward method of the `Attention` class.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                The hidden states of the query.
-            encoder_hidden_states (`torch.Tensor`, *optional*):
-                The hidden states of the encoder.
-            attention_mask (`torch.Tensor`, *optional*):
-                The attention mask to use. If `None`, no mask is applied.
-            **cross_attention_kwargs:
-                Additional keyword arguments to pass along to the cross attention.
-
-        Returns:
-            `torch.Tensor`: The output of the attention layer.
-        """
-        # The `Attention` class can call different attention processors / attention functions
-        # here we simply pass along all tensors to the selected processor class
-        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-
-        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
-        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
-
-        return self.processor(
-            self,
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-
-
-@maybe_allow_in_graph
-class CausalBasicTransformerBlock(BasicTransformerBlock):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        double_self_attention: bool = False,
-        upcast_attention: bool = False,
-        norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",
-        final_dropout: bool = False,
-    ):
-        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
-                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
-                                                          attention_bias, only_cross_attention, double_self_attention,
-                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
-        self.attn1 = CausalAttention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
-        cross_attention_kwargs: Dict[str, Any] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
-        if self.use_ada_layer_norm:
-            norm_hidden_states = self.norm1(hidden_states, timestep)
-        elif self.use_ada_layer_norm_zero:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-            )
-        else:
-            norm_hidden_states = self.norm1(hidden_states)
-
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
-        attn_output, cache = self.attn1(
-            norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-        if self.use_ada_layer_norm_zero:
-            attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = attn_output + hidden_states
-
-        # 2. Cross-Attention
-        if self.attn2 is not None:
-            norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-            )
-
-            attn_output = self.attn2(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                **cross_attention_kwargs,
-            )
-            hidden_states = attn_output + hidden_states
-
-        # 3. Feed-forward
-        norm_hidden_states = self.norm3(hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
-        if self._chunk_size is not None:
-            # "feed_forward_chunk_size" can be used to save memory
-            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
-                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
-                                 {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
-
-            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
-            ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
-                dim=self._chunk_dim,
-            )
-        else:
-            ff_output = self.ff(norm_hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = ff_output + hidden_states
-
-        return hidden_states, cache
-
 
 class ConditionalDecoder(nn.Module):
     def __init__(
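The hunk above removes the cache-carrying attention stack (`CausalAttnProcessor2_0`, `CausalAttention`, `CausalBasicTransformerBlock`); the decoder now relies on the plain `BasicTransformerBlock` (see the following hunks). For context, a hedged sketch of the key/value-cache pattern the deleted processor implemented; the function name and shapes are illustrative, not the removed API:

```python
import torch
import torch.nn.functional as F

def attend_with_kv_cache(query, key_new, value_new, kv_cache=None):
    # kv_cache packs the previous chunk's keys and values as (batch, t_cache, dim, 2),
    # mirroring the torch.stack([key, value], dim=3) layout in the removed code.
    if kv_cache is not None and kv_cache.size(1) > 0:
        key = torch.cat([kv_cache[..., 0], key_new], dim=1)
        value = torch.cat([kv_cache[..., 1], value_new], dim=1)
    else:
        key, value = key_new, value_new
    new_cache = torch.stack([key_new, value_new], dim=3)
    out = F.scaled_dot_product_attention(query, key, value)
    return out, new_cache

q = torch.randn(1, 4, 64)
k = torch.randn(1, 4, 64)
v = torch.randn(1, 4, 64)
out, cache = attend_with_kv_cache(q, k, v)                    # first chunk: no cache yet
out2, cache = attend_with_kv_cache(q, k, v, kv_cache=cache)   # later chunk also attends to cached K/V
```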
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
-
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
 
             transformer_blocks = nn.ModuleList(
                 [
-
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             )
             transformer_blocks = nn.ModuleList(
                 [
-
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -724,7 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
         Returns:
             _type_: _description_
         """
-
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
@@ -740,36 +434,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
         masks = [mask]
         for resnet, transformer_blocks, downsample in self.down_blocks:
             mask_down = masks[-1]
-            x
+            x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size,
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
             hiddens.append(x)  # Save hidden states for skip connections
-            x
+            x = downsample(x * mask_down)
             masks.append(mask_down[:, :, ::2])
         masks = masks[:-1]
         mask_mid = masks[-1]
 
         for resnet, transformer_blocks in self.mid_blocks:
-            x
+            x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size,
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
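In the hunk above, when `streaming is True` the attention mask is built with a fixed chunk size (`self.static_chunk_size`) and unlimited left context (the trailing `-1`), and `mask_to_bias` then turns the boolean mask into an additive attention bias. A rough, pure-torch illustration of that block-causal pattern; the real mask comes from `cosyvoice.utils.mask.add_optional_chunk_mask`, so treat this as a sketch of the shape only:

```python
import torch

def static_chunk_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # Frame i may attend to frame j iff j lies in the same chunk as i or in an
    # earlier chunk (unlimited left context), matching the streaming=True case above.
    idx = torch.arange(seq_len)
    chunk_end = (idx // chunk_size + 1) * chunk_size
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)    # (seq_len, seq_len), True = allowed

print(static_chunk_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])
```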
@@ -780,124 +474,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_up = masks.pop()
             skip = hiddens.pop()
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x
+            x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size,
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
-            x
-        x
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
         return output * mask
-
-    @torch.inference_mode()
-    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
-                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Forward pass of the UNet1DConditional model.
-
-        Args:
-            x (torch.Tensor): shape (batch_size, in_channels, time)
-            mask (_type_): shape (batch_size, 1, time)
-            t (_type_): shape (batch_size)
-            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
-            cond (_type_, optional): placeholder for future use. Defaults to None.
-
-        Raises:
-            ValueError: _description_
-            ValueError: _description_
-
-        Returns:
-            _type_: _description_
-        """
-
-        t = self.time_embeddings(t).to(t.dtype)
-        t = self.time_mlp(t)
-
-        x = pack([x, mu], "b * t")[0]
-
-        if spks is not None:
-            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
-            x = pack([x, spks], "b * t")[0]
-        if cond is not None:
-            x = pack([x, cond], "b * t")[0]
-
-        hiddens = []
-        masks = [mask]
-
-        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
-        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
-            mask_down = masks[-1]
-            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
-                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, down_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=down_blocks_kv_cache[index, i],
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            hiddens.append(x)  # Save hidden states for skip connections
-            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
-            masks.append(mask_down[:, :, ::2])
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-
-        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
-            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
-                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=mid_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-
-        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
-            mask_up = masks.pop()
-            skip = hiddens.pop()
-            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
-                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, up_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=up_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
-        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
-        output = self.final_proj(x * mask_up)
-        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
-            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
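With `forward_chunk` and its explicitly threaded conv/KV cache tensors removed, the up-sampling path above still fuses skip connections through `einops.pack`, as in the context line `x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]`. A small sketch of what that pack call does; the channel counts here are made up for illustration:

```python
import torch
from einops import pack

x = torch.randn(2, 256, 50)      # current feature map: (batch, channels, time)
skip = torch.randn(2, 512, 50)   # hidden state saved on the down path

# "b * t" keeps batch and time aligned and concatenates whatever the starred
# axis absorbs, i.e. the channel dimensions of the two tensors.
fused, packed_shapes = pack([x[:, :, :skip.shape[-1]], skip], "b * t")

assert fused.shape == (2, 256 + 512, 50)
```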
xinference/thirdparty/cosyvoice/flow/flow.py

@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         mask = (~make_pad_mask(feat_len)).to(h)
         # NOTE this is unnecessary, feat/h already same shape
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
             if random.random() < 0.5:
@@ -243,7 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-
+                  streaming,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -257,16 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 
         # text encode
         if finalize is True:
-            h, h_lengths
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths
-            cache['encoder_cache']['offset'] = encoder_cache[0]
-            cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
-            cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
-            cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
-            cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
-            cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
@@ -276,14 +268,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat,
+        feat, _ = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
             n_timesteps=10,
-
+            streaming=streaming
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat.float(),
+        return feat.float(), None
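The final hunk unpacks the decoder output as `feat, _ = self.decoder(...)` and returns `(feat.float(), None)`, so callers that expect a (mel, cache) pair keep working while the cache slot is now always `None`. The valid-frame mask on the line above it is the inverse of `make_pad_mask`; a pure-torch stand-in for that helper, for illustration only (the real one lives in `cosyvoice.utils.mask`):

```python
import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padded positions; the diff negates it so valid frames become 1.
    max_len = int(lengths.max())
    idx = torch.arange(max_len, device=lengths.device)
    return idx.unsqueeze(0) >= lengths.unsqueeze(1)     # (batch, max_len)

lengths = torch.tensor([4, 2])
valid = ~make_pad_mask(lengths)
print(valid.int())
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]])
```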