minicpmo-utils 0.1.0 (minicpmo_utils-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,281 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+ from typing import Dict, Optional
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from omegaconf import DictConfig
+ from cosyvoice.utils.mask import make_pad_mask
+
+
+ class MaskedDiffWithXvec(torch.nn.Module):
+     def __init__(self,
+                  input_size: int = 512,
+                  output_size: int = 80,
+                  spk_embed_dim: int = 192,
+                  output_type: str = "mel",
+                  vocab_size: int = 4096,
+                  input_frame_rate: int = 50,
+                  only_mask_loss: bool = True,
+                  encoder: torch.nn.Module = None,
+                  length_regulator: torch.nn.Module = None,
+                  decoder: torch.nn.Module = None,
+                  decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                           'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                         'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.decoder_conf = decoder_conf
+         self.mel_feat_conf = mel_feat_conf
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.input_frame_rate = input_frame_rate
+         logging.info(f"input frame rate={self.input_frame_rate}")
+         self.input_embedding = nn.Embedding(vocab_size, input_size)
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = decoder
+         self.length_regulator = length_regulator
+         self.only_mask_loss = only_mask_loss
+
+     def forward(
+             self,
+             batch: dict,
+             device: torch.device,
+     ) -> Dict[str, Optional[torch.Tensor]]:
+         token = batch['speech_token'].to(device)
+         token_len = batch['speech_token_len'].to(device)
+         feat = batch['speech_feat'].to(device)
+         feat_len = batch['speech_feat_len'].to(device)
+         embedding = batch['embedding'].to(device)
+
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat text and prompt_text
+         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # text encode
+         h, h_lengths = self.encoder(token, token_len)
+         h = self.encoder_proj(h)
+         h, h_lengths = self.length_regulator(h, feat_len)
+
+         # get conditions
+         conds = torch.zeros(feat.shape, device=token.device)
+         for i, j in enumerate(feat_len):
+             if random.random() < 0.5:
+                 continue
+             index = random.randint(0, int(0.3 * j))
+             conds[i, :index] = feat[i, :index]
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(feat_len)).to(h)
+         # NOTE this is unnecessary, feat/h already same shape
+         loss, _ = self.decoder.compute_loss(
+             feat.transpose(1, 2).contiguous(),
+             mask.unsqueeze(1),
+             h.transpose(1, 2).contiguous(),
+             embedding,
+             cond=conds
+         )
+         return {'loss': loss}
+
+     @torch.inference_mode()
+     def inference(self,
+                   token,
+                   token_len,
+                   prompt_token,
+                   prompt_token_len,
+                   prompt_feat,
+                   prompt_feat_len,
+                   embedding,
+                   flow_cache):
+         assert token.shape[0] == 1
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat speech token and prompt speech token
+         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # text encode
+         h, h_lengths = self.encoder(token, token_len)
+         h = self.encoder_proj(h)
+         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
+         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
+
+         # get conditions
+         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+         conds[:, :mel_len1] = prompt_feat
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+         feat, flow_cache = self.decoder(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=10,
+             prompt_len=mel_len1,
+             cache=flow_cache
+         )
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
+         return feat.float(), flow_cache
+
+
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
+     def __init__(self,
+                  input_size: int = 512,
+                  output_size: int = 80,
+                  spk_embed_dim: int = 192,
+                  output_type: str = "mel",
+                  vocab_size: int = 4096,
+                  input_frame_rate: int = 50,
+                  only_mask_loss: bool = True,
+                  token_mel_ratio: int = 2,
+                  pre_lookahead_len: int = 3,
+                  encoder: torch.nn.Module = None,
+                  decoder: torch.nn.Module = None,
+                  decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                           'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                         'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.decoder_conf = decoder_conf
+         self.mel_feat_conf = mel_feat_conf
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.input_frame_rate = input_frame_rate
+         logging.info(f"input frame rate={self.input_frame_rate}")
+         self.input_embedding = nn.Embedding(vocab_size, input_size)
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = decoder
+         self.only_mask_loss = only_mask_loss
+         self.token_mel_ratio = token_mel_ratio
+         self.pre_lookahead_len = pre_lookahead_len
+
+     def forward(
+             self,
+             batch: dict,
+             device: torch.device,
+     ) -> Dict[str, Optional[torch.Tensor]]:
+         token = batch['speech_token'].to(device)
+         token_len = batch['speech_token_len'].to(device)
+         feat = batch['speech_feat'].to(device)
+         feat_len = batch['speech_feat_len'].to(device)
+         embedding = batch['embedding'].to(device)
+
+         # NOTE unified training, static_chunk_size > 0 or = 0
+         streaming = True if random.random() < 0.5 else False
+
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat text and prompt_text
+         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # text encode
+         h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+         h = self.encoder_proj(h)
+
+         # get conditions
+         conds = torch.zeros(feat.shape, device=token.device)
+         for i, j in enumerate(feat_len):
+             if random.random() < 0.5:
+                 continue
+             index = random.randint(0, int(0.3 * j))
+             conds[i, :index] = feat[i, :index]
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
+         loss, _ = self.decoder.compute_loss(
+             feat.transpose(1, 2).contiguous(),
+             mask.unsqueeze(1),
+             h.transpose(1, 2).contiguous(),
+             embedding,
+             cond=conds,
+             streaming=streaming,
+         )
+         return {'loss': loss}
+
+     @torch.inference_mode()
+     def inference(self,
+                   token,
+                   token_len,
+                   prompt_token,
+                   prompt_token_len,
+                   prompt_feat,
+                   prompt_feat_len,
+                   embedding,
+                   streaming,
+                   finalize):
+         assert token.shape[0] == 1
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat text and prompt_text
+         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # text encode
+         if finalize is True:
+             h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+         else:
+             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+             h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
+         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+         h = self.encoder_proj(h)
+
+         # get conditions
+         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+         conds[:, :mel_len1] = prompt_feat
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+         feat, _ = self.decoder(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=10,
+             streaming=streaming
+         )
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
+         return feat.float(), None
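For orientation, here is a minimal standalone sketch (not part of the wheel) of the token-to-mel length mapping that `MaskedDiffWithXvec.inference` applies when it computes `mel_len2`. The 50 Hz token rate, 22050 Hz sampling rate, and hop size 256 are the defaults visible in the code above; the helper name is hypothetical.

```python
# Illustrative only: how the flow module maps a speech-token count to mel frames,
# mirroring int(token_len2 / self.input_frame_rate * 22050 / 256) above.

def tokens_to_mel_frames(n_tokens: int,
                         input_frame_rate: int = 50,
                         sampling_rate: int = 22050,
                         hop_size: int = 256) -> int:
    # duration in seconds = n_tokens / input_frame_rate
    # mel frames = duration * sampling_rate / hop_size, truncated like the code above
    return int(n_tokens / input_frame_rate * sampling_rate / hop_size)

if __name__ == "__main__":
    # 150 speech tokens (~3 s of audio) map to roughly 258 mel frames
    print(tokens_to_mel_frames(150))
```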
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,227 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn.functional as F
+ from matcha.models.components.flow_matching import BASECFM
+ from cosyvoice.utils.common import set_all_random_seed
+
+
+ class ConditionalCFM(BASECFM):
+     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+         super().__init__(
+             n_feats=in_channels,
+             cfm_params=cfm_params,
+             n_spks=n_spks,
+             spk_emb_dim=spk_emb_dim,
+         )
+         self.t_scheduler = cfm_params.t_scheduler
+         self.training_cfg_rate = cfm_params.training_cfg_rate
+         self.inference_cfg_rate = cfm_params.inference_cfg_rate
+         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+         # Just change the architecture of the estimator here
+         self.estimator = estimator
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+
+         z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+         cache_size = cache.shape[2]
+         # fix prompt and overlap part mu and z
+         if cache_size != 0:
+             z[:, :, :cache_size] = cache[:, :, :, 0]
+             mu[:, :, :cache_size] = cache[:, :, :, 1]
+         z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+         mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+         cache = torch.stack([z_cache, mu_cache], dim=-1)
+
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
+
+     def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
+         """
+         Fixed euler solver for ODEs.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): n_timesteps interpolated
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+         """
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+         t = t.unsqueeze(dim=0)
+
+         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+         # Or in future might add like a return_all_steps flag
+         sol = []
+
+         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+         x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+         mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+         mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+         t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+         spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+         cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+         for step in range(1, len(t_span)):
+             # Classifier-Free Guidance inference introduced in VoiceBox
+             x_in[:] = x
+             mask_in[:] = mask
+             mu_in[0] = mu
+             t_in[:] = t.unsqueeze(0)
+             spks_in[0] = spks
+             cond_in[0] = cond
+             dphi_dt = self.forward_estimator(
+                 x_in, mask_in,
+                 mu_in, t_in,
+                 spks_in,
+                 cond_in,
+                 streaming
+             )
+             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+             x = x + dt * dphi_dt
+             t = t + dt
+             sol.append(x)
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+         return sol[-1].float()
+
+     def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
+         if isinstance(self.estimator, torch.nn.Module):
+             return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
+         else:
+             [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+             # NOTE need to synchronize when switching stream
+             torch.cuda.current_stream().synchronize()
+             with stream:
+                 estimator.set_input_shape('x', (2, 80, x.size(2)))
+                 estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                 estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                 estimator.set_input_shape('t', (2,))
+                 estimator.set_input_shape('spks', (2, 80))
+                 estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                 data_ptrs = [x.contiguous().data_ptr(),
+                              mask.contiguous().data_ptr(),
+                              mu.contiguous().data_ptr(),
+                              t.contiguous().data_ptr(),
+                              spks.contiguous().data_ptr(),
+                              cond.contiguous().data_ptr(),
+                              x.data_ptr()]
+                 for i, j in enumerate(data_ptrs):
+                     estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                 # run trt engine
+                 assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                 torch.cuda.current_stream().synchronize()
+             self.estimator.release_estimator(estimator, stream)
+             return x
+
+     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
+         """Computes diffusion loss
+
+         Args:
+             x1 (torch.Tensor): Target
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): target mask
+                 shape: (batch_size, 1, mel_timesteps)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+
+         Returns:
+             loss: conditional flow matching loss
+             y: conditional flow
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         b, _, t = mu.shape
+
+         # random timestep
+         t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t = 1 - torch.cos(t * 0.5 * torch.pi)
+         # sample noise p(x_0)
+         z = torch.randn_like(x1)
+
+         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+         u = x1 - (1 - self.sigma_min) * z
+
+         # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+         if self.training_cfg_rate > 0:
+             cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+             mu = mu * cfg_mask.view(-1, 1, 1)
+             spks = spks * cfg_mask.view(-1, 1)
+             cond = cond * cfg_mask.view(-1, 1, 1)
+
+         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
+         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+         return loss, y
+
+
+ class CausalConditionalCFM(ConditionalCFM):
+     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+         set_all_random_seed(0)
+         self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+
+         z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
+         # fix prompt and overlap part mu and z
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
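For orientation, a minimal standalone sketch (not part of the wheel) of the two formulas `ConditionalCFM` relies on: the interpolation and regression target built in `compute_loss`, and the classifier-free guidance mix applied in `solve_euler`. `sigma_min` and the guidance rate follow the defaults shown above; the tensors are random stand-ins with illustrative shapes.

```python
# Illustrative only: core conditional flow matching and CFG formulas, standalone.
import torch

sigma_min, cfg_rate = 1e-6, 0.7
x1 = torch.randn(2, 80, 100)          # target mel (batch, n_feats, frames)
z = torch.randn_like(x1)              # noise sample x_0
t = torch.rand(2, 1, 1)               # random timestep per item

# compute_loss: point on the conditional flow, and the vector field the estimator regresses to
y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z

# solve_euler: guidance combination of conditional and unconditional estimator outputs
v_cond, v_uncond = torch.randn_like(x1), torch.randn_like(x1)
dphi_dt = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond
```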
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,70 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Tuple
+ import torch.nn as nn
+ import torch
+ from torch.nn import functional as F
+ from cosyvoice.utils.mask import make_pad_mask
+
+
+ class InterpolateRegulator(nn.Module):
+     def __init__(
+             self,
+             channels: int,
+             sampling_ratios: Tuple,
+             out_channels: int = None,
+             groups: int = 1,
+     ):
+         super().__init__()
+         self.sampling_ratios = sampling_ratios
+         out_channels = out_channels or channels
+         model = nn.ModuleList([])
+         if len(sampling_ratios) > 0:
+             for _ in sampling_ratios:
+                 module = nn.Conv1d(channels, channels, 3, 1, 1)
+                 norm = nn.GroupNorm(groups, channels)
+                 act = nn.Mish()
+                 model.extend([module, norm, act])
+         model.append(
+             nn.Conv1d(channels, out_channels, 1, 1)
+         )
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x, ylens=None):
+         # x in (B, T, D)
+         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
+         x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
+         out = self.model(x).transpose(1, 2).contiguous()
+         olens = ylens
+         return out * mask, olens
+
+     def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
+         # in inference mode, interpolate the prompt token and the token (head/mid/tail) separately, so we get a clear separation point in the mel
+         # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
+         # x in (B, T, D)
+         if x2.shape[1] > 40:
+             x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+             x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                    mode='linear')
+             x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+         else:
+             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+         if x1.shape[1] != 0:
+             x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+             x = torch.concat([x1, x2], dim=2)
+         else:
+             x = x2
+         out = self.model(x).transpose(1, 2).contiguous()
+         return out, mel_len1 + mel_len2
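For orientation, a minimal usage sketch (not part of the wheel) of `InterpolateRegulator.forward`, which resamples token-rate encoder output to a target mel length before the flow decoder. The `channels=512`, `sampling_ratios=[1, 1, 1, 1]`, `out_channels=80`, and length values here are assumed example values, not taken from this package's configs.

```python
# Illustrative only: upsample (B, T_token, D) encoder features to a mel-length grid.
import torch
from cosyvoice.flow.length_regulator import InterpolateRegulator

regulator = InterpolateRegulator(channels=512, sampling_ratios=[1, 1, 1, 1], out_channels=80)
x = torch.randn(1, 100, 512)      # (B, T_token, D) encoder output, e.g. 100 tokens at 50 Hz
ylens = torch.tensor([344])       # target mel frame count per batch item
out, olens = regulator(x, ylens)  # out: (1, 344, 80), olens: tensor([344])
print(out.shape, olens)
```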