minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
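
The wheel ships five top-level packages (cosyvoice, matcha, minicpmo, s3tokenizer, stepaudio2 — matching the five entries in top_level.txt), so after installation they are imported directly rather than under a `minicpmo_utils` namespace. A quick sanity-check sketch, assuming only that the wheel is installed:

```python
# Verify the five top-level packages bundled in minicpmo-utils 0.1.0.
import cosyvoice
import matcha
import minicpmo
import s3tokenizer
import stepaudio2

print(minicpmo.__name__, matcha.__name__)  # "minicpmo matcha"
```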
matcha/models/baselightningmodule.py
@@ -0,0 +1,209 @@
+"""
+This is a base lightning module that can be used to train a model.
+The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
+"""
+import inspect
+from abc import ABC
+from typing import Any, Dict
+
+import torch
+from lightning import LightningModule
+from lightning.pytorch.utilities import grad_norm
+
+from matcha import utils
+from matcha.utils.utils import plot_tensor
+
+log = utils.get_pylogger(__name__)
+
+
+class BaseLightningClass(LightningModule, ABC):
+    def update_data_statistics(self, data_statistics):
+        if data_statistics is None:
+            data_statistics = {
+                "mel_mean": 0.0,
+                "mel_std": 1.0,
+            }
+
+        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
+        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
+
+    def configure_optimizers(self) -> Any:
+        optimizer = self.hparams.optimizer(params=self.parameters())
+        if self.hparams.scheduler not in (None, {}):
+            scheduler_args = {}
+            # Manage last epoch for exponential schedulers
+            if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+                if hasattr(self, "ckpt_loaded_epoch"):
+                    current_epoch = self.ckpt_loaded_epoch - 1
+                else:
+                    current_epoch = -1
+
+            scheduler_args.update({"optimizer": optimizer})
+            scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
+            scheduler.last_epoch = current_epoch
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": self.hparams.scheduler.lightning_args.interval,
+                    "frequency": self.hparams.scheduler.lightning_args.frequency,
+                    "name": "learning_rate",
+                },
+            }
+
+        return {"optimizer": optimizer}
+
+    def get_losses(self, batch):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+
+        dur_loss, prior_loss, diff_loss = self(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+            out_size=self.out_size,
+        )
+        return {
+            "dur_loss": dur_loss,
+            "prior_loss": prior_loss,
+            "diff_loss": diff_loss,
+        }
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+
+    def training_step(self, batch: Any, batch_idx: int):
+        loss_dict = self.get_losses(batch)
+        self.log(
+            "step",
+            float(self.global_step),
+            on_step=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            "sub_loss/train_dur_loss",
+            loss_dict["dur_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "sub_loss/train_prior_loss",
+            loss_dict["prior_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "sub_loss/train_diff_loss",
+            loss_dict["diff_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        total_loss = sum(loss_dict.values())
+        self.log(
+            "loss/train",
+            total_loss,
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+        return {"loss": total_loss, "log": loss_dict}
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        loss_dict = self.get_losses(batch)
+        self.log(
+            "sub_loss/val_dur_loss",
+            loss_dict["dur_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "sub_loss/val_prior_loss",
+            loss_dict["prior_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "sub_loss/val_diff_loss",
+            loss_dict["diff_loss"],
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        total_loss = sum(loss_dict.values())
+        self.log(
+            "loss/val",
+            total_loss,
+            on_step=True,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+        return total_loss
+
+    def on_validation_end(self) -> None:
+        if self.trainer.is_global_zero:
+            one_batch = next(iter(self.trainer.val_dataloaders))
+            if self.current_epoch == 0:
+                log.debug("Plotting original samples")
+                for i in range(2):
+                    y = one_batch["y"][i].unsqueeze(0).to(self.device)
+                    self.logger.experiment.add_image(
+                        f"original/{i}",
+                        plot_tensor(y.squeeze().cpu()),
+                        self.current_epoch,
+                        dataformats="HWC",
+                    )
+
+            log.debug("Synthesising...")
+            for i in range(2):
+                x = one_batch["x"][i].unsqueeze(0).to(self.device)
+                x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
+                spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
+                output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
+                attn = output["attn"]
+                self.logger.experiment.add_image(
+                    f"generated_enc/{i}",
+                    plot_tensor(y_enc.squeeze().cpu()),
+                    self.current_epoch,
+                    dataformats="HWC",
+                )
+                self.logger.experiment.add_image(
+                    f"generated_dec/{i}",
+                    plot_tensor(y_dec.squeeze().cpu()),
+                    self.current_epoch,
+                    dataformats="HWC",
+                )
+                self.logger.experiment.add_image(
+                    f"alignment/{i}",
+                    plot_tensor(attn.squeeze().cpu()),
+                    self.current_epoch,
+                    dataformats="HWC",
+                )
+
+    def on_before_optimizer_step(self, optimizer):
+        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
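
As the module docstring above notes, everything outside the model definition (loss aggregation, logging, optimizer and scheduler wiring, validation plotting) lives in BaseLightningClass, so a concrete model only has to supply the pieces the base class calls into: a forward that returns (dur_loss, prior_loss, diff_loss), a synthesise method, and an out_size attribute. In this package that concrete model is matcha/models/matcha_tts.py; the snippet below is only a minimal, hypothetical sketch of the contract, not code from the wheel:

```python
import functools

import torch

from matcha.models.baselightningmodule import BaseLightningClass


class ToyTTS(BaseLightningClass):
    """Hypothetical subclass illustrating the hooks BaseLightningClass expects."""

    def __init__(self, optimizer=functools.partial(torch.optim.Adam, lr=1e-4),
                 scheduler=None, out_size=None, data_statistics=None):
        super().__init__()
        # configure_optimizers() reads self.hparams.optimizer / self.hparams.scheduler.
        self.save_hyperparameters(logger=False)
        self.out_size = out_size                      # read by get_losses()
        self.update_data_statistics(data_statistics)  # registers mel_mean / mel_std buffers
        self.dummy = torch.nn.Linear(1, 1)            # stand-in for the real network

    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None):
        # get_losses() expects the model to return three scalar losses.
        zero = self.dummy(torch.zeros(1, 1, device=self.device)).sum() * 0.0
        return zero, zero, zero

    def synthesise(self, x, x_lengths, n_timesteps=10, spks=None):
        # on_validation_end() plots these three outputs to TensorBoard.
        n_frames = int(x_lengths.max())
        mel = torch.zeros(1, 80, n_frames, device=self.device)
        attn = torch.zeros(1, x.shape[-1], n_frames, device=self.device)
        return {"encoder_outputs": mel, "decoder_outputs": mel, "attn": attn}
```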
matcha/models/components/decoder.py
@@ -0,0 +1,443 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from conformer import ConformerBlock
+from diffusers.models.activations import get_activation
+from einops import pack, rearrange, repeat
+
+from matcha.models.components.transformer import BasicTransformerBlock
+
+
+class SinusoidalPosEmb(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
+
+    def forward(self, x, scale=1000):
+        if x.ndim < 1:
+            x = x.unsqueeze(0)
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class Block1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
+            torch.nn.GroupNorm(groups, dim_out),
+            nn.Mish(),
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class ResnetBlock1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+        super().__init__()
+        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+
+        self.block1 = Block1D(dim, dim_out, groups=groups)
+        self.block2 = Block1D(dim_out, dim_out, groups=groups)
+
+        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
+
+    def forward(self, x, mask, time_emb):
+        h = self.block1(x, mask)
+        h += self.mlp(time_emb).unsqueeze(-1)
+        h = self.block2(h, mask)
+        output = h + self.res_conv(x * mask)
+        return output
+
+
+class Downsample1D(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+    ):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        self.act = get_activation(act_fn)
+
+        if out_dim is not None:
+            time_embed_dim_out = out_dim
+        else:
+            time_embed_dim_out = time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            self.post_act = get_activation(post_act_fn)
+
+    def forward(self, sample, condition=None):
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(inputs)
+
+        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            outputs = self.conv(outputs)
+
+        return outputs
+
+
+class ConformerWrapper(ConformerBlock):
+    def __init__(  # pylint: disable=useless-super-delegation
+        self,
+        *,
+        dim,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        conv_expansion_factor=2,
+        conv_kernel_size=31,
+        attn_dropout=0,
+        ff_dropout=0,
+        conv_dropout=0,
+        conv_causal=False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        timestep=None,
+    ):
+        return super().forward(x=hidden_states, mask=attention_mask.bool())
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+        down_block_type="transformer",
+        mid_block_type="transformer",
+        up_block_type="transformer",
+    ):
+        super().__init__()
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        down_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for i in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        mid_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i]
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+
+            resnet = ResnetBlock1D(
+                dim=2 * input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        up_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+
+        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+
+        self.initialize_weights()
+        # nn.init.normal_(self.final_proj.weight)
+
+    @staticmethod
+    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
+        if block_type == "conformer":
+            block = ConformerWrapper(
+                dim=dim,
+                dim_head=attention_head_dim,
+                heads=num_heads,
+                ff_mult=1,
+                conv_expansion_factor=2,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                conv_dropout=dropout,
+                conv_kernel_size=31,
+            )
+        elif block_type == "transformer":
+            block = BasicTransformerBlock(
+                dim=dim,
+                num_attention_heads=num_heads,
+                attention_head_dim=attention_head_dim,
+                dropout=dropout,
+                activation_fn=act_fn,
+            )
+        else:
+            raise ValueError(f"Unknown block type {block_type}")
+
+        return block
+
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.GroupNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_down = rearrange(mask_down, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_down,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_down = rearrange(mask_down, "b t -> b 1 t")
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_mid,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_mid = rearrange(mask_mid, "b t -> b 1 t")
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_up = rearrange(mask_up, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_up,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_up = rearrange(mask_up, "b t -> b 1 t")
+            x = upsample(x * mask_up)
+
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+
+        return output * mask
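
This Decoder is the 1D U-Net used as the vector-field estimator by matcha/models/components/flow_matching.py: x and mu (plus a broadcast speaker embedding, if given) are packed along the channel dimension, pushed through down/mid/up blocks with skip connections, and projected back to out_channels. A minimal shape-check sketch with illustrative sizes only (not taken from the package's configs; the "snakebeta" activation mirrors Matcha-TTS's default decoder setting):

```python
import torch

from matcha.models.components.decoder import Decoder

# Illustrative sizes: 80 mel bins, a 64-dim speaker embedding, 96 frames.
n_feats, spk_dim, n_frames = 80, 64, 96

# in_channels must equal the packed width: n_feats (x) + n_feats (mu) + spk_dim.
decoder = Decoder(in_channels=2 * n_feats + spk_dim, out_channels=n_feats, act_fn="snakebeta")

x = torch.randn(2, n_feats, n_frames)   # noisy sample at timestep t
mu = torch.randn(2, n_feats, n_frames)  # encoder output / condition
mask = torch.ones(2, 1, n_frames)       # (batch, 1, time) padding mask
t = torch.rand(2)                       # one timestep per batch item
spks = torch.randn(2, spk_dim)          # (batch, condition_channels)

out = decoder(x, mask, mu, t, spks=spks)
print(out.shape)  # torch.Size([2, 80, 96])
```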