minicpmo_utils-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/modules/flow_components/estimator.py
@@ -0,0 +1,974 @@
import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
                                         AdaLayerNormZero, ApproximateGELU)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import pack, rearrange, repeat

from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import \
    add_optional_chunk_mask


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * -1.0e+10
    return mask


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)

    Args:
        in_features: shape of the input
        out_features: shape of the output
        alpha: trainable parameter that controls frequency
        alpha_trainable: whether alpha is trainable
        alpha_logscale: whether to use log scale for alpha
        alpha is initialized to 1 by default, higher values = higher-frequency.
        beta is initialized to 1 by default, higher values = higher-magnitude.
        alpha will be trained along with the rest of your model.
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Block1D(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()
        assert act_fn == "silu", "act_fn must be silu"

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = nn.SiLU()

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (self.causal_padding, 0), value=0.0)
        x = super(CausalConv1d, self).forward(x)
        return x


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


class ConditionalDecoder(nn.Module):
    """
    This decoder requires an input with the same shape of the target. So, if your text content
    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        channels: tuple of channel dimensions
        dropout: dropout rate
        attention_head_dim: dimension of attention heads
        n_blocks: number of transformer blocks
        num_mid_blocks: number of middle blocks
        num_heads: number of attention heads
        act_fn: activation function name
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


class CausalConditionalDecoder(ConditionalDecoder):
    """
    This decoder requires an input with the same shape of the target. So, if your text content
    is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        channels: list of channel dimensions
        dropout: dropout rate
        attention_head_dim: dimension of attention heads
        n_blocks: number of transformer blocks
        num_mid_blocks: number of middle blocks
        num_heads: number of attention heads
        act_fn: activation function name
        static_chunk_size: size of static chunks
        num_decoding_left_chunks: number of left chunks for decoding
    """

    def __init__(
        self,
        in_channels=320,
        out_channels=80,
        channels=[256],  # noqa
        dropout=0.0,
        attention_head_dim=64,
        n_blocks=4,
        num_mid_blocks=12,
        num_heads=8,
        act_fn="gelu",
        static_chunk_size=50,
        num_decoding_left_chunks=-1,
    ):
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
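
For orientation only (not part of the wheel contents), here is a minimal usage sketch of the CausalConditionalDecoder added by this file. It assumes the wheel and its torch/diffusers/einops dependencies are installed and uses the class defaults: 80-bin mel features, with the noisy mel, mu, the time-repeated spks, and cond concatenated on the channel axis, which is why in_channels defaults to 320 (4 x 80). Tensor shapes follow the forward() docstring; the batch size and frame count below are arbitrary.

import torch

from stepaudio2.flashcosyvoice.modules.flow_components.estimator import \
    CausalConditionalDecoder

# Illustrative sketch only: randomly initialized weights, dummy inputs.
decoder = CausalConditionalDecoder(in_channels=320, out_channels=80).eval()

batch, mel_bins, frames = 2, 80, 100
x = torch.randn(batch, mel_bins, frames)      # noisy mel being denoised
mu = torch.randn(batch, mel_bins, frames)     # encoder-output conditioning
spks = torch.randn(batch, mel_bins)           # speaker embedding, repeated over time inside forward()
cond = torch.randn(batch, mel_bins, frames)   # prompt-mel conditioning
mask = torch.ones(batch, 1, frames)           # 1 = valid frame
t = torch.rand(batch)                         # flow-matching timestep per sample

with torch.no_grad():
    out = decoder(x, mask, mu, t, spks=spks, cond=cond, streaming=False)
print(out.shape)  # torch.Size([2, 80, 100])

With a single entry in channels (the default [256]), both the downsample and upsample stages fall back to stride-1 causal convolutions, so the output keeps the input frame count; passing streaming=True switches the attention masks to static chunks of self.static_chunk_size frames.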