minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/flow/decoder.py
@@ -0,0 +1,494 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
+from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+from matcha.models.components.transformer import BasicTransformerBlock
+
+
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = kernel_size - 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class ConditionalDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+    ):
+        """
+        This decoder requires an input with the same shape of the target. So, if your text content
+        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+        """
+        super().__init__()
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for _ in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i] * 2
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+            resnet = ResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.GroupNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x = resnet(x, mask_up, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+        return output * mask
+
+
+class CausalConditionalDecoder(ConditionalDecoder):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+        static_chunk_size=50,
+        num_decoding_left_chunks=2,
+    ):
+        """
+        This decoder requires an input with the same shape of the target. So, if your text content
+        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
+        """
+        torch.nn.Module.__init__(self)
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for _ in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i] * 2
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    BasicTransformerBlock(
+                        dim=output_channel,
+                        num_attention_heads=num_heads,
+                        attention_head_dim=attention_head_dim,
+                        dropout=dropout,
+                        activation_fn=act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else CausalConv1d(output_channel, output_channel, 3)
+            )
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+        self.final_block = CausalBlock1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+        self.initialize_weights()
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x = resnet(x, mask_up, t)
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+        return output * mask
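Editor's note: the hunk above adds cosyvoice/flow/decoder.py in full. As a quick orientation aid (not part of the wheel), the sketch below shows one way the new CausalConditionalDecoder could be smoke-tested in isolation. The sizes are assumptions in the style of CosyVoice 2 setups (80-bin mel features for x, mu and cond, an 80-dim speaker embedding, hence in_channels=320 after einops.pack), and act_fn="gelu" is passed only because the bundled matcha BasicTransformerBlock supports it; treat every number here as illustrative rather than as values taken from this package.

import torch

from cosyvoice.flow.decoder import CausalConditionalDecoder

# Assumed, CosyVoice2-style dimensions: x/mu/cond are 80-bin mel sequences and
# spks is an 80-dim speaker embedding, so the packed input has 320 channels.
decoder = CausalConditionalDecoder(
    in_channels=320,
    out_channels=80,
    act_fn="gelu",  # assumption: overrides the "snake" default for portability
)
decoder.eval()

batch, mel_dim, frames = 2, 80, 100
x = torch.randn(batch, mel_dim, frames)     # noisy sample at flow-matching step t
mu = torch.randn(batch, mel_dim, frames)    # token-encoder output, same length as x
cond = torch.randn(batch, mel_dim, frames)  # prompt-conditioning features
spks = torch.randn(batch, mel_dim)          # speaker embedding, repeated over time inside forward()
mask = torch.ones(batch, 1, frames)         # 1 = valid frame
t = torch.rand(batch)                       # flow-matching timestep in [0, 1]

# CausalConv1d in the same file left-pads by kernel_size - 1 zeros, so the
# "Causal" variants never look at future frames; streaming=True additionally
# switches the attention masks to static chunks of static_chunk_size frames.
with torch.no_grad():
    out = decoder(x, mask, mu, t, spks=spks, cond=cond, streaming=False)

print(out.shape)  # expected: torch.Size([2, 80, 100])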