minicpmo-utils 0.1.0__py3-none-any.whl

This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in the public registry.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/modules/flow_components/estimator.py
@@ -0,0 +1,974 @@
+import math
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
+                                         AdaLayerNormZero, ApproximateGELU)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import pack, rearrange, repeat
+
+from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import \
+    add_optional_chunk_mask
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
+    return mask
+
+
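For orientation, mask_to_bias turns a boolean attention mask into an additive bias: positions marked True become 0.0 and positions marked False become a large negative value, so masked logits vanish after softmax. A minimal sketch, assuming the definition above is in scope (the shapes are illustrative, not taken from the package):

    import torch

    # hypothetical mask: batch 1, 4 queries, 4 keys; True = may attend
    mask = torch.tensor([[[True, True, False, False]] * 4])

    bias = mask_to_bias(mask, torch.float32)
    # bias is 0.0 where mask is True and -1.0e10 where it is False,
    # so it can be added directly to attention scores before softmax
    assert bias.shape == mask.shape
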
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+
+    Args:
+        in_features: shape of the input
+        out_features: shape of the output
+        alpha: trainable parameter that controls frequency
+        alpha_trainable: whether alpha is trainable
+        alpha_logscale: whether to use log scale for alpha
+        alpha is initialized to 1 by default, higher values = higher-frequency.
+        beta is initialized to 1 by default, higher values = higher-magnitude.
+        alpha will be trained along with the rest of your model.
+    """
+
+    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+        super().__init__()
+        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.proj = LoRACompatibleLinear(in_features, out_features)
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        x = self.proj(x)
+        if self.alpha_logscale:
+            alpha = torch.exp(self.alpha)
+            beta = torch.exp(self.beta)
+        else:
+            alpha = self.alpha
+            beta = self.beta
+
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+
+        return x
+
+
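Note that this variant bundles a LoRACompatibleLinear projection, so unlike the plain Snake activation quoted in the docstring example it maps in_features to out_features. A minimal sketch of how it could be exercised, assuming the class above is in scope (shapes are illustrative):

    import torch

    act = SnakeBeta(in_features=256, out_features=1024)
    x = torch.randn(2, 50, 256)   # (batch, time, channels)
    y = act(x)                    # projected, then x + 1/beta * sin^2(alpha * x)
    assert y.shape == (2, 50, 1024)
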
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        if activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+        elif activation_fn == "snakebeta":
+            act_fn = SnakeBeta(dim, inner_dim)
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states):
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (:
+            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            )
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                # scale_qk=False, # uncomment this to not to use flash attention
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+    ):
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states
+
+
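In this file the block is only instantiated in self-attention form (no cross_attention_dim), and the attention mask it receives is the additive bias produced by mask_to_bias. A minimal sketch under those assumptions, with illustrative sizes:

    import torch

    block = BasicTransformerBlock(
        dim=256,
        num_attention_heads=4,
        attention_head_dim=64,
        dropout=0.0,
        activation_fn="gelu",
    )
    x = torch.randn(2, 50, 256)          # (batch, time, channels)
    attn_bias = torch.zeros(2, 50, 50)   # additive mask; 0 = attend everywhere
    y = block(hidden_states=x, attention_mask=attn_bias, timestep=None)
    assert y.shape == x.shape
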
+class SinusoidalPosEmb(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
+
+    def forward(self, x, scale=1000):
+        if x.ndim < 1:
+            x = x.unsqueeze(0)
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
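The embedding follows the standard transformer recipe: each timestep t is scaled (by 1000 by default) and multiplied by geometrically spaced frequencies, and the sin and cos of those products are concatenated into a vector of size dim. A short sketch, with illustrative values:

    import torch

    emb = SinusoidalPosEmb(dim=320)
    t = torch.tensor([0.25, 0.5])   # e.g. flow-matching timesteps in [0, 1]
    e = emb(t)                      # scaled by 1000 internally before sin/cos
    assert e.shape == (2, 320)
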
+class Block1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
+            torch.nn.GroupNorm(groups, dim_out),
+            nn.Mish(),
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class ResnetBlock1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+        super().__init__()
+        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+
+        self.block1 = Block1D(dim, dim_out, groups=groups)
+        self.block2 = Block1D(dim_out, dim_out, groups=groups)
+
+        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
+
+    def forward(self, x, mask, time_emb):
+        h = self.block1(x, mask)
+        h += self.mlp(time_emb).unsqueeze(-1)
+        h = self.block2(h, mask)
+        output = h + self.res_conv(x * mask)
+        return output
+
+
+class Downsample1D(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+    ):
+        super().__init__()
+        assert act_fn == "silu", "act_fn must be silu"
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        self.act = nn.SiLU()
+
+        if out_dim is not None:
+            time_embed_dim_out = out_dim
+        else:
+            time_embed_dim_out = time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            self.post_act = nn.SiLU()
+
+    def forward(self, sample, condition=None):
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(inputs)
+
+        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            outputs = self.conv(outputs)
+
+        return outputs
+
+
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
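With the default use_conv_transpose=True, ConvTranspose1d(channels, channels, kernel_size=4, stride=2, padding=1) exactly doubles the time dimension (L_out = (L_in - 1) * 2 - 2 * 1 + 4 = 2 * L_in), mirroring the halving done by Downsample1D's stride-2 convolution. A quick sketch with illustrative sizes:

    import torch

    up = Upsample1D(channels=256)           # conv-transpose path by default
    x = torch.randn(2, 256, 50)             # (batch, channels, time)
    assert up(x).shape == (2, 256, 100)     # time axis doubled
    assert Downsample1D(256)(up(x)).shape == (2, 256, 50)
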
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = kernel_size - 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
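CausalConv1d keeps the convolution causal by left-padding with kernel_size - 1 zeros instead of using symmetric padding, so the output at frame t depends only on frames <= t; this is what lets the causal decoder below run chunk by chunk in streaming mode. A minimal check of that property, assuming the class above is in scope:

    import torch

    conv = CausalConv1d(in_channels=8, out_channels=8, kernel_size=3)
    x = torch.randn(1, 8, 20)
    y1 = conv(x)
    y2 = conv(x.clone().index_fill_(2, torch.tensor([10]), 5.0))  # perturb frame 10
    # frames before the perturbation are unchanged: no lookahead
    assert torch.allclose(y1[..., :10], y2[..., :10])
    assert y1.shape == x.shape
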
+class ConditionalDecoder(nn.Module):
+    """
+    This decoder requires an input with the same shape of the target. So, if your text content
+    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+
+    Args:
+        in_channels: number of input channels
+        out_channels: number of output channels
+        channels: tuple of channel dimensions
+        dropout: dropout rate
+        attention_head_dim: dimension of attention heads
+        n_blocks: number of transformer blocks
+        num_mid_blocks: number of middle blocks
+        num_heads: number of attention heads
+        act_fn: activation function name
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+    ):
+        super().__init__()
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for _ in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i] * 2
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+            resnet = ResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.GroupNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x = resnet(x, mask_up, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+        return output * mask
+
+
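In the flow-matching estimator role, this decoder receives the noisy feature x, the conditioning mu (same shape as x), an optional speaker embedding broadcast over time, and a per-sample timestep; x, mu, spks and cond are packed along the channel axis, so in_channels must equal their combined channel count. A minimal sketch of a call, assuming the definitions above are in scope; the sizes are illustrative and not the package's configuration, and an explicit act_fn is passed because FeedForward has no branch for the default "snake":

    import torch

    decoder = ConditionalDecoder(
        in_channels=240,       # 80 (x) + 80 (mu) + 80 (speaker embedding), packed on dim 1
        out_channels=80,
        channels=(256, 256),
        act_fn="gelu",
    )
    B, T = 2, 100
    x = torch.randn(B, 80, T)      # noisy mel at the current flow step
    mu = torch.randn(B, 80, T)     # conditioning features, same shape as x
    mask = torch.ones(B, 1, T)
    spks = torch.randn(B, 80)      # broadcast over time inside forward
    t = torch.rand(B)              # flow-matching timestep per sample
    out = decoder(x, mask, mu, t, spks=spks)
    assert out.shape == (B, 80, T)
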
+class CausalConditionalDecoder(ConditionalDecoder):
+    """
+    This decoder requires an input with the same shape of the target. So, if your text content
+    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+
+    Args:
+        in_channels: number of input channels
+        out_channels: number of output channels
+        channels: list of channel dimensions
+        dropout: dropout rate
+        attention_head_dim: dimension of attention heads
+        n_blocks: number of transformer blocks
+        num_mid_blocks: number of middle blocks
+        num_heads: number of attention heads
+        act_fn: activation function name
+        static_chunk_size: size of static chunks
+        num_decoding_left_chunks: number of left chunks for decoding
+    """
+
+    def __init__(
+        self,
+        in_channels=320,
+        out_channels=80,
+        channels=[256],  # noqa
+        dropout=0.0,
+        attention_head_dim=64,
+        n_blocks=4,
+        num_mid_blocks=12,
+        num_heads=8,
+        act_fn="gelu",
+        static_chunk_size=50,
+        num_decoding_left_chunks=-1,
+    ):
+        torch.nn.Module.__init__(self)
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for _ in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i] * 2
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+        self.final_block = CausalBlock1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+        self.initialize_weights()
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x = resnet(x, mask_up, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+        return output * mask