minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/cosyvoice2/flow/decoder_dit.py
@@ -0,0 +1,585 @@
+import math
+import torch
+import numpy as np
+from typing import Optional
+from einops import pack, rearrange, repeat
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+
+"""
+DiT-v5
+- Add convolution in DiTBlock to increase high-freq component
+"""
+
+
+class MLP(torch.nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer=nn.GELU,
+        norm_layer=None,
+        bias=True,
+        drop=0.,
+    ):
+        super().__init__()
+        hidden_features = hidden_features or in_features
+        out_features = out_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop)
+        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop2 = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.norm(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+
+class Attention(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        head_dim: int = 64,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.,
+        proj_drop: float = 0.,
+        norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.inner_dim = num_heads * head_dim
+        self.scale = head_dim ** -0.5
+
+        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
+        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
+        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
+
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.proj = nn.Linear(self.inner_dim, dim)
+
+    def to_heads(self, ts: torch.Tensor):
+        b, t, c = ts.shape
+        # (b, t, nh, c)
+        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
+        ts = ts.transpose(1, 2)
+        return ts
+
+    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+        """Args:
+            x(torch.Tensor): shape (b, t, c)
+            attn_mask(torch.Tensor): shape (b, t, t)
+        """
+        b, t, c = x.shape
+
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        q = self.to_heads(q) # (b, nh, t, c)
+        k = self.to_heads(k)
+        v = self.to_heads(v)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        attn_mask = attn_mask.unsqueeze(1)
+        x = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask=attn_mask,
+            dropout_p=self.attn_drop.p if self.training else 0.,
+        ) # (b, nh, t, c)
+        x = x.transpose(1, 2).reshape(b, t, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def forward_chunk(self, x: torch.Tensor, att_cache: torch.Tensor = None, attn_mask: torch.Tensor = None):
+        """
+        Args:
+            x: shape (b, dt, c)
+            att_cache: shape (b, nh, t, c*2)
+        """
+        b, t, c = x.shape
+
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        q = self.to_heads(q) # (b, nh, t, c)
+        k = self.to_heads(k)
+        v = self.to_heads(v)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # unpack {k,v}_cache
+        if att_cache is not None:
+            if attn_mask is not None:
+                k_cache, v_cache = att_cache.chunk(2, dim=3)
+                k = torch.cat([k, k_cache], dim=2)
+                v = torch.cat([v, v_cache], dim=2)
+
+            else:
+                k_cache, v_cache = att_cache.chunk(2, dim=3)
+                k = torch.cat([k, k_cache], dim=2)
+                v = torch.cat([v, v_cache], dim=2)
+
+        # new {k,v}_cache
+        new_att_cache = torch.cat([k, v], dim=3)
+        # attn_mask = torch.ones((b, 1, t, t1), dtype=torch.bool, device=x.device)
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(1)
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) # (b, nh, t, c)
+        x = x.transpose(1, 2).reshape(b, t, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, new_att_cache
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale) + shift
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+        # from SinusoidalPosEmb
+        self.scale = 1000
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                  These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half) / half
+        ).to(t)
+        args = t[:, None] * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+# Convolution related
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
+        if cnn_cache is None:
+            cnn_cache = x.new_zeros((x.shape[0], self.in_channels, self.causal_padding[0]))
+        x = torch.cat([cnn_cache, x], dim=2)
+        new_cnn_cache = x[..., -self.causal_padding[0]:]
+        x = super(CausalConv1d, self).forward(x)
+        return x, new_cnn_cache
+
+
+class CausalConvBlock(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: int = 3,
+                 ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+
+        self.block = torch.nn.Sequential(
+            # norm
+            # conv1
+            Transpose(1, 2),
+            CausalConv1d(in_channels, out_channels, kernel_size),
+            Transpose(1, 2),
+            # norm & act
+            nn.LayerNorm(out_channels),
+            nn.Mish(),
+            # conv2
+            Transpose(1, 2),
+            CausalConv1d(out_channels, out_channels, kernel_size),
+            Transpose(1, 2),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
+        """
+        Args:
+            x: shape (b, t, c)
+            mask: shape (b, t, 1)
+        """
+        if mask is not None: x = x * mask
+        x = self.block(x)
+        if mask is not None: x = x * mask
+        return x
+
+    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
+        """
+        Args:
+            x: shape (b, dt, c)
+            cnn_cache: shape (b, c1+c2, 2)
+        """
+        if cnn_cache is not None:
+            cnn_cache1, cnn_cache2 = cnn_cache.split((self.in_channels, self.out_channels), dim=1)
+        else:
+            cnn_cache1, cnn_cache2 = None, None
+        x = self.block[0](x)
+        x, new_cnn_cache1 = self.block[1].forward_chunk(x, cnn_cache1)
+        x = self.block[2:6](x)
+        x, new_cnn_cache2 = self.block[6].forward_chunk(x, cnn_cache2)
+        x = self.block[7](x)
+        new_cnn_cache = torch.cat((new_cnn_cache1, new_cnn_cache2), dim=1)
+        return x, new_cnn_cache
+
+
+class DiTBlock(nn.Module):
+    """
+    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+    """
+    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 9 * hidden_size, bias=True)
+        )
+
+    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor):
+        """Args
+            x: shape (b, t, c)
+            c: shape (b, 1, c)
+            attn_mask: shape (b, t, t), bool type attention mask
+        """
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
+            = self.adaLN_modulation(c).chunk(9, dim=-1)
+        # attention
+        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask)
+        # conv
+        x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv))
+        # mlp
+        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+        return x
+
+    def forward_chunk(self, x: torch.Tensor, c: torch.Tensor, cnn_cache: torch.Tensor = None, att_cache: torch.Tensor = None, mask: torch.Tensor = None):
+        """
+        Args:
+            x: shape (b, dt, c)
+            c: shape (b, 1, c)
+            cnn_cache: shape (b, c1+c2, 2)
+            att_cache: shape (b, nh, t, c * 2)
+        """
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
+            = self.adaLN_modulation(c).chunk(9, dim=-1)
+        # attention
+        x_att, new_att_cache = self.attn.forward_chunk(modulate(self.norm1(x), shift_msa, scale_msa), att_cache, mask)
+        x = x + gate_msa * x_att
+        # conv
+        x_conv, new_cnn_cache = self.conv.forward_chunk(modulate(self.norm3(x), shift_conv, scale_conv), cnn_cache)
+        x = x + gate_conv * x_conv
+        # mlp
+        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+        return x, new_cnn_cache, new_att_cache
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer of DiT.
+    """
+    def __init__(self, hidden_size, out_channels):
+        super().__init__()
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+
+class DiT(nn.Module):
+    """
+    Diffusion model with a Transformer backbone.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        mlp_ratio: float = 4.0,
+        depth: int = 28,
+        num_heads: int = 8,
+        head_dim: int = 64,
+        hidden_size: int = 256,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.t_embedder = TimestepEmbedder(hidden_size)
+
+        self.in_proj = nn.Linear(in_channels, hidden_size)
+
+        self.blocks = nn.ModuleList([
+            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
+        ])
+        self.final_layer = FinalLayer(hidden_size, self.out_channels)
+
+        self.initialize_weights()
+
+        self.enable_cuda_graph = False
+        self.use_cuda_graph = False
+
+        self.graph_chunk = {}
+        self.inference_buffers_chunk = {}
+        self.max_size_chunk = {}
+
+        self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False)
+        self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False)
+
+    def initialize_weights(self):
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP:
+        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+        # Zero-out adaLN modulation layers in DiT blocks:
+        for block in self.blocks:
+            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+        nn.init.constant_(self.final_layer.linear.weight, 0)
+        nn.init.constant_(self.final_layer.linear.bias, 0)
+
+    def _init_cuda_graph_chunk(self):
+        # get dtype, device from registered buffer
+        dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device
+        # init cuda graph for streaming forward
+        with torch.no_grad():
+            for chunk_size in [30, 48, 96]:
+                if chunk_size == 30 or chunk_size == 48:
+                    max_size = 500
+                    self.max_size_chunk[chunk_size] = max_size
+                else:
+                    max_size = 1000
+                    self.max_size_chunk[chunk_size] = max_size
+                static_x1 = torch.zeros((2, 320, chunk_size), dtype=dtype, device=device)
+                static_t1 = torch.zeros((2, 1, 512), dtype=dtype, device=device)
+                static_mask1 = torch.ones((2, chunk_size, max_size+chunk_size), dtype=torch.bool, device=device)
+                static_att_cache = torch.zeros((16, 2, 8, max_size, 128), dtype=dtype, device=device)
+                static_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
+                static_inputs1 = [
+                    static_x1,
+                    static_t1,
+                    static_mask1,
+                    static_cnn_cache,
+                    static_att_cache,
+                ]
+                static_new_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
+                static_new_att_cache = torch.zeros((16, 2, 8, max_size+chunk_size, 128), dtype=dtype, device=device)
+                self.blocks_forward_chunk(
+                    static_inputs1[0],
+                    static_inputs1[1],
+                    static_inputs1[2],
+                    static_inputs1[3],
+                    static_inputs1[4],
+                    static_new_cnn_cache,
+                    static_new_att_cache)
+                graph_chunk = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(graph_chunk):
+                    static_out1 = self.blocks_forward_chunk(static_x1, static_t1, static_mask1, static_cnn_cache, static_att_cache, static_new_cnn_cache, static_new_att_cache)
+                static_outputs1 = [static_out1, static_new_cnn_cache, static_new_att_cache]
+                self.inference_buffers_chunk[chunk_size] = {
+                    'static_inputs': static_inputs1,
+                    'static_outputs': static_outputs1
+                }
+                self.graph_chunk[chunk_size] = graph_chunk
+
+    def _init_cuda_graph_all(self):
+        self._init_cuda_graph_chunk()
+        self.use_cuda_graph = True
+        print(f"CUDA Graph initialized successfully for chunk decoder")
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None):
+        """Args:
+            x: shape (b, c, t)
+            mask: shape (b, 1, t)
+            t: shape (b,)
+            spks: shape (b, c)
+            cond: shape (b, c, t)
+        """
+        # (sfy) chunk training strategy should not be open-sourced
+
+        # time
+        t = self.t_embedder(t).unsqueeze(1) # (b, 1, c)
+        x = pack([x, mu], "b * t")[0]
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        return self.blocks_forward(x, t, mask)
+
+    def blocks_forward(self, x, t, mask):
+        x = x.transpose(1, 2)
+        attn_mask = mask.bool()
+        x = self.in_proj(x)
+        for block in self.blocks:
+            x = block(x, t, attn_mask)
+        x = self.final_layer(x, t)
+        x = x.transpose(1, 2)
+        return x
+
+    def forward_chunk(self,
+                      x: torch.Tensor,
+                      mu: torch.Tensor,
+                      t: torch.Tensor,
+                      spks: torch.Tensor,
+                      cond: torch.Tensor,
+                      cnn_cache: torch.Tensor = None,
+                      att_cache: torch.Tensor = None,
+                      ):
+        """
+        Args:
+            x: shape (b, dt, c)
+            mu: shape (b, dt, c)
+            t: shape (b,)
+            spks: shape (b, c)
+            cond: shape (b, dt, c)
+            cnn_cache: shape (depth, b, c1+c2, 2)
+            att_cache: shape (depth, b, nh, t, c * 2)
+        """
+
+        # time
+        t = self.t_embedder(t).unsqueeze(1) # (b, 1, c)
+        x = pack([x, mu], "b * t")[0]
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        # create fake cache
+        if cnn_cache is None:
+            cnn_cache = [None] * len(self.blocks)
+        if att_cache is None:
+            att_cache = [None] * len(self.blocks)
+        if att_cache[0] is not None:
+            last_att_len = att_cache.shape[3]
+        else:
+            last_att_len = 0
+        chunk_size = x.shape[2]
+        mask = torch.ones(x.shape[0], chunk_size, last_att_len+chunk_size, dtype=torch.bool, device=x.device)
+        if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]:
+            padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size]+chunk_size), dtype=mask.dtype, device=mask.device)
+            padded_mask[:, :, :mask.shape[-1]] = mask
+            padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device)
+            padded_att_cache[:, :, :, :last_att_len, :] = att_cache
+            self.inference_buffers_chunk[chunk_size]['static_inputs'][0].copy_(x)
+            self.inference_buffers_chunk[chunk_size]['static_inputs'][1].copy_(t)
+            self.inference_buffers_chunk[chunk_size]['static_inputs'][2].copy_(padded_mask)
+            self.inference_buffers_chunk[chunk_size]['static_inputs'][3].copy_(cnn_cache)
+            self.inference_buffers_chunk[chunk_size]['static_inputs'][4].copy_(padded_att_cache)
+            self.graph_chunk[chunk_size].replay()
+            x = self.inference_buffers_chunk[chunk_size]['static_outputs'][0][:, :, :chunk_size]
+            new_cnn_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][1]
+            new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size+last_att_len, :]
+        else:
+            mask = None
+            x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer)
+            new_cnn_cache = self.cnn_cache_buffer
+            new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len+chunk_size, :]
+
+        return x, new_cnn_cache, new_att_cache
+
+    def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None):
+        x = x.transpose(1, 2)
+        x = self.in_proj(x)
+        for b_idx, block in enumerate(self.blocks):
+            x, this_new_cnn_cache, this_new_att_cache \
+                = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
+            cnn_cache_buffer[b_idx] = this_new_cnn_cache
+            att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache
+        x = self.final_layer(x, t)
+        x = x.transpose(1, 2)
+        return x
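
For reference, a minimal usage sketch of the offline forward() path of the DiT decoder above. It is not part of the wheel: the in_channels=320 / hidden_size=512 / depth=16 / num_heads=8 / head_dim=64 configuration and the even 80-channel split across x, mu, spks and cond are assumptions inferred from the docstring shapes and the CUDA-graph buffers, and may not match the released checkpoints.

    import torch
    from stepaudio2.cosyvoice2.flow.decoder_dit import DiT

    # Assumed hyper-parameters (see note above); only the constructor signature comes from the source.
    model = DiT(in_channels=320, out_channels=80, depth=16,
                num_heads=8, head_dim=64, hidden_size=512).eval()

    b, t_len = 2, 96
    x = torch.randn(b, 80, t_len)     # noisy mel at the current flow-matching step
    mu = torch.randn(b, 80, t_len)    # token-encoder output aligned to mel frames
    spks = torch.randn(b, 80)         # speaker embedding, broadcast over time inside forward()
    cond = torch.randn(b, 80, t_len)  # prompt-mel conditioning
    mask = torch.ones(b, 1, t_len)    # all-ones padding mask -> full attention
    ts = torch.rand(b)                # timestep in [0, 1]; scaled by 1000 inside TimestepEmbedder

    with torch.no_grad():
        v = model(x, mask, mu, ts, spks=spks, cond=cond)
    print(v.shape)  # torch.Size([2, 80, 96]) -- prediction with out_channels mel bins

The streaming path (forward_chunk) additionally threads per-block cnn_cache / att_cache tensors and, once _init_cuda_graph_all() has been called, replays a captured CUDA graph for chunk sizes 30, 48 and 96.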