xinference 1.5.1__py3-none-any.whl → 1.6.0.post1__py3-none-any.whl

This diff shows the changes between the two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/METADATA +59 -39
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/top_level.txt +0 -0
@@ -11,14 +11,16 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ from typing import Tuple, Optional, Dict, Any
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import pack, rearrange, repeat
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
  from cosyvoice.utils.common import mask_to_bias
  from cosyvoice.utils.mask import add_optional_chunk_mask
  from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
- from matcha.models.components.transformer import BasicTransformerBlock
+ from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph


  class Transpose(torch.nn.Module):
@@ -27,34 +29,11 @@ class Transpose(torch.nn.Module):
  self.dim0 = dim0
  self.dim1 = dim1

- def forward(self, x: torch.Tensor):
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
  x = torch.transpose(x, self.dim0, self.dim1)
  return x


- class CausalBlock1D(Block1D):
- def __init__(self, dim: int, dim_out: int):
- super(CausalBlock1D, self).__init__(dim, dim_out)
- self.block = torch.nn.Sequential(
- CausalConv1d(dim, dim_out, 3),
- Transpose(1, 2),
- nn.LayerNorm(dim_out),
- Transpose(1, 2),
- nn.Mish(),
- )
-
- def forward(self, x: torch.Tensor, mask: torch.Tensor):
- output = self.block(x * mask)
- return output * mask
-
-
- class CausalResnetBlock1D(ResnetBlock1D):
- def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
- super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
- self.block1 = CausalBlock1D(dim, dim_out)
- self.block2 = CausalBlock1D(dim_out, dim_out)
-
-
  class CausalConv1d(torch.nn.Conv1d):
  def __init__(
  self,
@@ -76,12 +55,339 @@ class CausalConv1d(torch.nn.Conv1d):
  padding_mode=padding_mode,
  device=device, dtype=dtype)
  assert stride == 1
- self.causal_padding = (kernel_size - 1, 0)
+ self.causal_padding = kernel_size - 1

- def forward(self, x: torch.Tensor):
- x = F.pad(x, self.causal_padding)
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cache.size(2) == 0:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
+ else:
+ assert cache.size(2) == self.causal_padding
+ x = torch.concat([cache, x], dim=2)
+ cache = x[:, :, -self.causal_padding:]
  x = super(CausalConv1d, self).forward(x)
- return x
+ return x, cache
+
+
+ class CausalBlock1D(Block1D):
+ def __init__(self, dim: int, dim_out: int):
+ super(CausalBlock1D, self).__init__(dim, dim_out)
+ self.block = torch.nn.Sequential(
+ CausalConv1d(dim, dim_out, 3),
+ Transpose(1, 2),
+ nn.LayerNorm(dim_out),
+ Transpose(1, 2),
+ nn.Mish(),
+ )
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ output, cache = self.block[0](x * mask, cache)
+ for i in range(1, len(self.block)):
+ output = self.block[i](output)
+ return output * mask, cache
+
+
+ class CausalResnetBlock1D(ResnetBlock1D):
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+ self.block1 = CausalBlock1D(dim, dim_out)
+ self.block2 = CausalBlock1D(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
+ block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ h, block1_cache = self.block1(x, mask, block1_cache)
+ h += self.mlp(time_emb).unsqueeze(-1)
+ h, block2_cache = self.block2(h, mask, block2_cache)
+ output = h + self.res_conv(x * mask)
+ return output, block1_cache, block2_cache
+
+
+ class CausalAttnProcessor2_0(AttnProcessor2_0):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ super(CausalAttnProcessor2_0, self).__init__()
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.Tensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
+ `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key_cache = attn.to_k(encoder_hidden_states)
+ value_cache = attn.to_v(encoder_hidden_states)
+ # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
+ if cache.size(0) != 0:
+ key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
+ value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
+ else:
+ key, value = key_cache, value_cache
+ cache = torch.stack([key_cache, value_cache], dim=3)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states, cache
+
+
+ @maybe_allow_in_graph
+ class CausalAttention(Attention):
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor2_0"] = None,
+ out_dim: int = None,
+ ):
+ super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
+ cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
+ spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
+ _from_deprecated_attn_block, processor, out_dim)
+ processor = CausalAttnProcessor2_0()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ **cross_attention_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cache=cache,
+ **cross_attention_kwargs,
+ )
+
+
+ @maybe_allow_in_graph
+ class CausalBasicTransformerBlock(BasicTransformerBlock):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ ):
+ super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
+ cross_attention_dim, activation_fn, num_embeds_ada_norm,
+ attention_bias, only_cross_attention, double_self_attention,
+ upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
+ self.attn1 = CausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 1. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ attn_output, cache = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+ cache=cache,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
+ {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states, cache


  class ConditionalDecoder(nn.Module):
@@ -89,7 +395,6 @@ class ConditionalDecoder(nn.Module):
  self,
  in_channels,
  out_channels,
- causal=False,
  channels=(256, 256),
  dropout=0.05,
  attention_head_dim=64,
@@ -106,7 +411,7 @@ class ConditionalDecoder(nn.Module):
  channels = tuple(channels)
  self.in_channels = in_channels
  self.out_channels = out_channels
- self.causal = causal
+
  self.time_embeddings = SinusoidalPosEmb(in_channels)
  time_embed_dim = channels[0] * 4
  self.time_mlp = TimestepEmbedding(
@@ -123,8 +428,7 @@ class ConditionalDecoder(nn.Module):
  input_channel = output_channel
  output_channel = channels[i]
  is_last = i == len(channels) - 1
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
- ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  transformer_blocks = nn.ModuleList(
  [
  BasicTransformerBlock(
@@ -138,16 +442,14 @@ class ConditionalDecoder(nn.Module):
  ]
  )
  downsample = (
- Downsample1D(output_channel) if not is_last else
- CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  )
  self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

  for _ in range(num_mid_blocks):
  input_channel = channels[-1]
  out_channels = channels[-1]
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
- ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

  transformer_blocks = nn.ModuleList(
  [
@@ -169,11 +471,7 @@ class ConditionalDecoder(nn.Module):
  input_channel = channels[i] * 2
  output_channel = channels[i + 1]
  is_last = i == len(channels) - 2
- resnet = CausalResnetBlock1D(
- dim=input_channel,
- dim_out=output_channel,
- time_emb_dim=time_embed_dim,
- ) if self.causal else ResnetBlock1D(
+ resnet = ResnetBlock1D(
  dim=input_channel,
  dim_out=output_channel,
  time_emb_dim=time_embed_dim,
@@ -193,10 +491,10 @@ class ConditionalDecoder(nn.Module):
  upsample = (
  Upsample1D(output_channel, use_conv_transpose=True)
  if not is_last
- else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  )
  self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
- self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
+ self.final_block = Block1D(channels[-1], channels[-1])
  self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
  self.initialize_weights()

@@ -214,7 +512,7 @@ class ConditionalDecoder(nn.Module):
  if m.bias is not None:
  nn.init.constant_(m.bias, 0)

- def forward(self, x, mask, mu, t, spks=None, cond=None):
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
  """Forward pass of the UNet1DConditional model.

  Args:
@@ -249,9 +547,8 @@ class ConditionalDecoder(nn.Module):
  mask_down = masks[-1]
  x = resnet(x, mask_down, t)
  x = rearrange(x, "b c t -> b t c").contiguous()
- # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
- attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
  for transformer_block in transformer_blocks:
  x = transformer_block(
  hidden_states=x,
@@ -268,9 +565,8 @@ class ConditionalDecoder(nn.Module):
  for resnet, transformer_blocks in self.mid_blocks:
  x = resnet(x, mask_mid, t)
  x = rearrange(x, "b c t -> b t c").contiguous()
- # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
- attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
  for transformer_block in transformer_blocks:
  x = transformer_block(
  hidden_states=x,
@@ -285,9 +581,8 @@ class ConditionalDecoder(nn.Module):
  x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
  x = resnet(x, mask_up, t)
  x = rearrange(x, "b c t -> b t c").contiguous()
- # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
- attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
- attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
  for transformer_block in transformer_blocks:
  x = transformer_block(
  hidden_states=x,
@@ -299,3 +594,310 @@ class ConditionalDecoder(nn.Module):
  x = self.final_block(x, mask_up)
  output = self.final_proj(x * mask_up)
  return output * mask
+
+
+ class CausalConditionalDecoder(ConditionalDecoder):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ channels=(256, 256),
+ dropout=0.05,
+ attention_head_dim=64,
+ n_blocks=1,
+ num_mid_blocks=2,
+ num_heads=4,
+ act_fn="snake",
+ static_chunk_size=50,
+ num_decoding_left_chunks=2,
+ ):
+ """
+ This decoder requires an input with the same shape of the target. So, if your text content
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+ """
+ torch.nn.Module.__init__(self)
+ channels = tuple(channels)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
+ time_embed_dim = channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=in_channels,
+ time_embed_dim=time_embed_dim,
+ act_fn="silu",
+ )
+ self.static_chunk_size = static_chunk_size
+ self.num_decoding_left_chunks = num_decoding_left_chunks
+ self.down_blocks = nn.ModuleList([])
+ self.mid_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = in_channels
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_last = i == len(channels) - 1
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ transformer_blocks = nn.ModuleList(
+ [
+ CausalBasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ downsample = (
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
+ )
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+ for _ in range(num_mid_blocks):
+ input_channel = channels[-1]
+ out_channels = channels[-1]
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+ transformer_blocks = nn.ModuleList(
+ [
+ CausalBasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+ channels = channels[::-1] + (channels[0],)
+ for i in range(len(channels) - 1):
+ input_channel = channels[i] * 2
+ output_channel = channels[i + 1]
+ is_last = i == len(channels) - 2
+ resnet = CausalResnetBlock1D(
+ dim=input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ )
+ transformer_blocks = nn.ModuleList(
+ [
+ CausalBasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ upsample = (
+ Upsample1D(output_channel, use_conv_transpose=True)
+ if not is_last
+ else CausalConv1d(output_channel, output_channel, 3)
+ )
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+ self.final_block = CausalBlock1D(channels[-1], channels[-1])
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+ self.initialize_weights()
+
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+ """Forward pass of the UNet1DConditional model.
+
+ Args:
+ x (torch.Tensor): shape (batch_size, in_channels, time)
+ mask (_type_): shape (batch_size, 1, time)
+ t (_type_): shape (batch_size)
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+ cond (_type_, optional): placeholder for future use. Defaults to None.
+
+ Raises:
+ ValueError: _description_
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+
+ t = self.time_embeddings(t).to(t.dtype)
+ t = self.time_mlp(t)
+
+ x = pack([x, mu], "b * t")[0]
+
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+ if cond is not None:
+ x = pack([x, cond], "b * t")[0]
+
+ hiddens = []
+ masks = [mask]
+ for resnet, transformer_blocks, downsample in self.down_blocks:
+ mask_down = masks[-1]
+ x, _, _ = resnet(x, mask_down, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for transformer_block in transformer_blocks:
+ x, _ = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ hiddens.append(x) # Save hidden states for skip connections
+ x, _ = downsample(x * mask_down)
+ masks.append(mask_down[:, :, ::2])
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for resnet, transformer_blocks in self.mid_blocks:
+ x, _, _ = resnet(x, mask_mid, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for transformer_block in transformer_blocks:
+ x, _ = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+
+ for resnet, transformer_blocks, upsample in self.up_blocks:
+ mask_up = masks.pop()
+ skip = hiddens.pop()
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+ x, _, _ = resnet(x, mask_up, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for transformer_block in transformer_blocks:
+ x, _ = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ x, _ = upsample(x * mask_up)
+ x, _ = self.final_block(x, mask_up)
+ output = self.final_proj(x * mask_up)
+ return output * mask
+
+ @torch.inference_mode()
+ def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
+ down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+ mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+ up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+ final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward pass of the UNet1DConditional model.
+
+ Args:
+ x (torch.Tensor): shape (batch_size, in_channels, time)
+ mask (_type_): shape (batch_size, 1, time)
+ t (_type_): shape (batch_size)
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+ cond (_type_, optional): placeholder for future use. Defaults to None.
+
+ Raises:
+ ValueError: _description_
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+
+ t = self.time_embeddings(t).to(t.dtype)
+ t = self.time_mlp(t)
+
+ x = pack([x, mu], "b * t")[0]
+
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+ if cond is not None:
+ x = pack([x, cond], "b * t")[0]
+
+ hiddens = []
+ masks = [mask]
+
+ down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+ mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
+ up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+ for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
+ mask_down = masks[-1]
+ x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
+ resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for i, transformer_block in enumerate(transformer_blocks):
+ x, down_blocks_kv_cache_new[index, i] = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ cache=down_blocks_kv_cache[index, i],
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ hiddens.append(x) # Save hidden states for skip connections
+ x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
+ masks.append(mask_down[:, :, ::2])
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
+ x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
+ resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for i, transformer_block in enumerate(transformer_blocks):
+ x, mid_blocks_kv_cache_new[index, i] = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ cache=mid_blocks_kv_cache[index, i]
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+
+ for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
+ mask_up = masks.pop()
+ skip = hiddens.pop()
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+ x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
+ resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
+ for i, transformer_block in enumerate(transformer_blocks):
+ x, up_blocks_kv_cache_new[index, i] = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ cache=up_blocks_kv_cache[index, i]
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
+ x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
+ output = self.final_proj(x * mask_up)
+ return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
+ up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
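
Note on the streaming change above: the rewritten CausalConv1d now returns a per-layer cache alongside its output, so that chunked (streaming) inference reproduces offline inference exactly. The snippet below is a minimal, self-contained sketch of that cache mechanism written against plain PyTorch for illustration; the class name, padding rule, and cache shape mirror the diff, but it is not the packaged implementation.

import torch
import torch.nn.functional as F

class CausalConv1dSketch(torch.nn.Conv1d):
    # Illustrative stand-in for the cache-based CausalConv1d added in this diff.
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__(in_channels, out_channels, kernel_size, stride=1)
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)):
        if cache.size(2) == 0:
            # First chunk: left-pad with zeros so the convolution stays causal.
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        else:
            # Later chunks: prepend the tail of the previous (padded) chunk.
            x = torch.concat([cache, x], dim=2)
        cache = x[:, :, -self.causal_padding:]
        return super().forward(x), cache

torch.manual_seed(0)
conv = CausalConv1dSketch(4, 4, kernel_size=3)
full = torch.randn(1, 4, 20)

with torch.no_grad():
    offline, _ = conv(full)              # one pass over the whole sequence
    out1, cache = conv(full[:, :, :10])  # streaming: first chunk, empty cache
    out2, _ = conv(full[:, :, 10:], cache)  # second chunk reuses the conv cache

assert torch.allclose(offline, torch.cat([out1, out2], dim=2), atol=1e-6)

Because the cache carries exactly the last kernel_size - 1 input frames across chunk boundaries, the chunked outputs concatenate to the same tensor the offline pass produces, which is what lets CausalConditionalDecoder.forward_chunk stream the flow decoder.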