minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/cosyvoice2/flow/flow.py
@@ -0,0 +1,230 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from stepaudio2.cosyvoice2.utils.mask import make_pad_mask
+ from stepaudio2.cosyvoice2.flow.flow_matching import CausalConditionalCFM
+ from stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2
+
+
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
+     def __init__(self,
+                  input_size: int = 512,
+                  output_size: int = 80,
+                  spk_embed_dim: int = 192,
+                  output_type: str = "mel",
+                  vocab_size: int = 5121,
+                  encoder: UpsampleConformerEncoderV2 = None,
+                  decoder: CausalConditionalCFM = None,
+                  input_embedding: torch.nn.Module = None,
+                  ):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.pre_lookahead_len = int(encoder.pre_lookahead_layer.pre_lookahead_len)
+         self.up_rate = int(encoder.up_layer.stride)
+         if input_embedding is None:
+             self.input_embedding = nn.Embedding(vocab_size, input_size)
+         else:
+             self.input_embedding = input_embedding
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = decoder
+
+         # xvec projection with CUDA Graph optimization:
+         # initialize CUDA Graph related state
+         self.enable_cuda_graph = False
+         self.static_embedding = None
+         self.static_output = None
+         self.graph = None
+         self.embedding_shape = None
+
+     def scatter_cuda_graph(self, enable_cuda_graph: bool):
+         self.enable_cuda_graph = enable_cuda_graph
+         if self.enable_cuda_graph:
+             # self.encoder.scatter_cuda_graph(enable_cuda_graph)
+             self.decoder.scatter_cuda_graph(enable_cuda_graph)
+
+     @torch.inference_mode()
+     def inference(self,
+                   token,
+                   token_len,
+                   prompt_token,
+                   prompt_token_len,
+                   prompt_feat,
+                   prompt_feat_len,
+                   embedding,
+                   n_timesteps: int = 10,
+                   ):
+         assert token.shape[0] == 1
+
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat text and prompt_text
+         token_len = prompt_token_len + token_len
+         token = torch.concat([prompt_token, token], dim=1)
+
+         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # token encode
+         h, _ = self.encoder.forward(token, token_len)
+         h = self.encoder_proj(h)
+
+         # condition: keep the prompt mel, zero-fill the frames to be generated
+         mel_len1 = prompt_feat.shape[1]
+         mel_len2 = h.shape[1] - prompt_feat.shape[1]
+
+         conds = torch.zeros_like(h)
+         conds[:, :mel_len1] = prompt_feat
+         conds = conds.transpose(1, 2).contiguous()
+
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+
+         feat = self.decoder.forward(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=n_timesteps,
+         )
+
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
+         return feat
+
+     @torch.inference_mode()
+     def setup_cache(self,
+                     token: torch.Tensor,
+                     mel: torch.Tensor,
+                     spk: torch.Tensor,
+                     n_timesteps: int = 10,
+                     ):
+         """
+         Args:
+             token: shape (b, t), with look-ahead tokens
+             mel: shape (b, t, c), ground-truth mel
+             spk: shape (b, 192), speaker embedding
+         Returns:
+             cache: dict {
+                 'conformer_cnn_cache': ..., 'conformer_att_cache': ...,
+                 'estimator_cnn_cache': ..., 'estimator_att_cache': ...
+             }
+         """
+         # check that the look-ahead tokens are included
+         assert (token.shape[1] - self.pre_lookahead_len) * self.up_rate == mel.shape[1], (token.shape, mel.shape)
+
+         # xvec projection
+         spk = F.normalize(spk, dim=1)
+         spk = self.spk_embed_affine_layer(spk)
+
+         token = self.input_embedding(token)
+         # NOTE encoder.forward_chunk will strip the look-ahead part
+         h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
+             xs=token,
+             last_chunk=False,
+             cnn_cache=None,
+             att_cache=None,
+         )
+         h = self.encoder_proj(h)
+
+         feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
+             mu=h.transpose(1, 2).contiguous(),
+             spks=spk,
+             cond=mel.transpose(1, 2).contiguous(),
+             n_timesteps=n_timesteps,
+             temperature=1.0,
+             cnn_cache=None,
+             att_cache=None,
+         )
+
+         cache = {
+             'conformer_cnn_cache': conformer_cnn_cache,
+             'conformer_att_cache': conformer_att_cache,
+             'estimator_cnn_cache': estimator_cnn_cache,
+             'estimator_att_cache': estimator_att_cache,
+         }
+
+         return cache
+
+     @torch.inference_mode()
+     def inference_chunk(self,
+                         token: torch.Tensor,
+                         spk: torch.Tensor,
+                         cache: dict,
+                         last_chunk: bool = False,
+                         n_timesteps: int = 10,
+                         ):
+         """
+         Args:
+             token: shape (b, t), with look-ahead tokens
+             spk: shape (b, 192), speaker embedding
+             cache: dict {
+                 'conformer_cnn_cache': ...,
+                 ...
+             }
+         """
+         # unpack cache
+         conformer_cnn_cache = cache['conformer_cnn_cache']
+         conformer_att_cache = cache['conformer_att_cache']
+         estimator_cnn_cache = cache['estimator_cnn_cache']
+         estimator_att_cache = cache['estimator_att_cache']
+
+         # xvec projection
+         spk = F.normalize(spk, dim=1)
+         spk = self.spk_embed_affine_layer(spk)
+
+         token = self.input_embedding(token)
+         # if this is not the last chunk, h is shorter than xs by lookahead_length * stride (6)
+         h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
+             xs=token,
+             last_chunk=last_chunk,
+             cnn_cache=conformer_cnn_cache,
+             att_cache=conformer_att_cache,
+         )
+         h = self.encoder_proj(h)
+
+         cond = torch.zeros_like(h)
+         # forward estimator
+         feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
+             mu=h.transpose(1, 2).contiguous(),
+             spks=spk,
+             cond=cond.transpose(1, 2).contiguous(),
+             n_timesteps=n_timesteps,
+             temperature=1.0,
+             cnn_cache=estimator_cnn_cache,
+             att_cache=estimator_att_cache,
+         )
+
+         new_cache = {
+             'conformer_cnn_cache': conformer_cnn_cache,
+             'conformer_att_cache': conformer_att_cache,
+             'estimator_cnn_cache': estimator_cnn_cache,
+             'estimator_att_cache': estimator_att_cache,
+         }
+
+         return feat, new_cache
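The `setup_cache` / `inference_chunk` pair above forms a streaming protocol: `setup_cache` primes the conformer and estimator caches from the prompt tokens and their ground-truth mel, and each subsequent `inference_chunk` call consumes one token chunk (still carrying its look-ahead tokens) and threads the returned cache into the next call. A minimal sketch of that loop, assuming an already-constructed `CausalMaskedDiffWithXvec` instance; the names `flow`, `prompt_tokens`, `prompt_mel`, `spk_emb`, and `token_chunks` are hypothetical placeholders, not identifiers from the package:

import torch

@torch.inference_mode()
def stream_mels(flow, prompt_tokens, prompt_mel, spk_emb, token_chunks):
    # Prime the conformer/estimator caches from the prompt.
    # prompt_tokens: (1, t_p) incl. look-ahead tokens;
    # prompt_mel: (1, (t_p - pre_lookahead_len) * up_rate, 80); spk_emb: (1, 192).
    cache = flow.setup_cache(prompt_tokens, prompt_mel, spk_emb)
    mels = []
    for i, chunk in enumerate(token_chunks):  # each chunk: (1, t_i), incl. look-ahead tokens
        feat, cache = flow.inference_chunk(
            chunk, spk_emb, cache,
            last_chunk=(i == len(token_chunks) - 1),
        )
        mels.append(feat)  # mel frames for this chunk, shape (1, 80, t')
    return torch.cat(mels, dim=2)

Each returned `feat` can be vocoded immediately, which is what makes the chunked path usable for low-latency synthesis.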
stepaudio2/cosyvoice2/flow/flow_matching.py
@@ -0,0 +1,205 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn.functional as F
+
+ from stepaudio2.cosyvoice2.flow.decoder_dit import DiT
+
+
+ """
+ Inference wrapper
+ """
+ class CausalConditionalCFM(torch.nn.Module):
+     def __init__(self, estimator: DiT, inference_cfg_rate: float = 0.7):
+         super().__init__()
+         self.estimator = estimator
+         self.inference_cfg_rate = inference_cfg_rate
+         self.out_channels = estimator.out_channels
+         # enough noise for a maximum of 600 s at 50 mel frames per second
+         self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False)
+
+         self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False)
+         self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False)
+
+     def scatter_cuda_graph(self, enable_cuda_graph: bool):
+         if enable_cuda_graph:
+             self.estimator._init_cuda_graph_all()
+
+     def solve_euler(self, x, t_span, mu, mask, spks, cond):
+         """
+         Fixed-step Euler solver for the probability-flow ODE.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): interpolated timesteps
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output mask
+                 shape: (batch_size, 1, mel_timesteps)
+             spks (torch.Tensor): speaker embedding
+                 shape: (batch_size, spk_emb_dim)
+             cond (torch.Tensor): conditioning features
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+         t = t.unsqueeze(dim=0)
+         assert self.inference_cfg_rate > 0, 'inference_cfg_rate must be > 0'
+
+         # constant during denoising
+         mask_in = torch.cat([mask, mask], dim=0)
+         mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+         spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
+         cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
+
+         for step in range(1, len(t_span)):
+
+             x_in = torch.cat([x, x], dim=0)
+             t_in = torch.cat([t, t], dim=0)
+
+             dphi_dt = self.estimator.forward(
+                 x_in,
+                 mask_in,
+                 mu_in,
+                 t_in,
+                 spks_in,
+                 cond_in,
+             )
+             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+             x = x + dt * dphi_dt
+             t = t + dt
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+         return x
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0):
+         z = self.rand_noise[:, :, :mu.size(2)] * temperature
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         # cosine scheduling
+         t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span, mu, mask, spks, cond)
+
+     def solve_euler_chunk(self,
+                           x: torch.Tensor,
+                           t_span: torch.Tensor,
+                           mu: torch.Tensor,
+                           spks: torch.Tensor,
+                           cond: torch.Tensor,
+                           cnn_cache: torch.Tensor = None,
+                           att_cache: torch.Tensor = None,
+                           ):
+         """
+         Fixed-step Euler solver for the probability-flow ODE, one chunk at a time.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): interpolated timesteps
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             spks (torch.Tensor): speaker embedding
+                 shape: (batch_size, spk_emb_dim)
+             cond (torch.Tensor): conditioning features
+                 shape: (batch_size, n_feats, mel_timesteps)
+             cnn_cache: shape (n_time, depth, b, c1+c2, 2)
+             att_cache: shape (n_time, depth, b, nh, t, c * 2)
+         """
+         assert self.inference_cfg_rate > 0, 'inference_cfg_rate must be > 0'
+
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+         t = t.unsqueeze(dim=0)  # (b,)
+
+         # set up the initial (empty) per-timestep caches
+         if cnn_cache is None:
+             cnn_cache = [None for _ in range(len(t_span) - 1)]
+         if att_cache is None:
+             att_cache = [None for _ in range(len(t_span) - 1)]
+
+         if att_cache[0] is not None:
+             last_att_len = att_cache.shape[4]
+         else:
+             last_att_len = 0
+
+         # constant during denoising
+         mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+         spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
+         cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
+         for step in range(1, len(t_span)):
+             this_att_cache = att_cache[step - 1]
+             this_cnn_cache = cnn_cache[step - 1]
+
+             dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
+                 x=x.repeat(2, 1, 1),
+                 mu=mu_in,
+                 t=t.repeat(2),
+                 spks=spks_in,
+                 cond=cond_in,
+                 cnn_cache=this_cnn_cache,
+                 att_cache=this_att_cache,
+             )
+             dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0)
+             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+             x = x + dt * dphi_dt
+             t = t + dt
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+             # store the next chunk's cache for this timestep
+             self.cnn_cache_buffer[step - 1] = this_new_cnn_cache
+             self.att_cache_buffer[step - 1][:, :, :, :x.shape[2] + last_att_len, :] = this_new_att_cache
+
+         cnn_cache = self.cnn_cache_buffer
+         att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2] + last_att_len, :]
+         return x, cnn_cache, att_cache
+
+     @torch.inference_mode()
+     def forward_chunk(self,
+                       mu: torch.Tensor,
+                       spks: torch.Tensor,
+                       cond: torch.Tensor,
+                       n_timesteps: int = 10,
+                       temperature: float = 1.0,
+                       cnn_cache: torch.Tensor = None,
+                       att_cache: torch.Tensor = None,
+                       ):
+         """
+         Args:
+             mu (torch.Tensor): shape (b, c, t)
+             spks (torch.Tensor): shape (b, 192)
+             cond (torch.Tensor): shape (b, c, t)
+             cnn_cache: shape (n_time, depth, b, c1+c2, 2)
+             att_cache: shape (n_time, depth, b, nh, t, c * 2)
+         """
+         # get the time offset from att_cache
+         offset = att_cache.shape[4] if att_cache is not None else 0
+         z = self.rand_noise[:, :, offset:offset + mu.size(2)] * temperature
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         # cosine scheduling
+         t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         x, new_cnn_cache, new_att_cache = self.solve_euler_chunk(
+             x=z,
+             t_span=t_span,
+             mu=mu,
+             spks=spks,
+             cond=cond,
+             att_cache=att_cache,
+             cnn_cache=cnn_cache,
+         )
+         return x, new_cnn_cache, new_att_cache
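`solve_euler` and `solve_euler_chunk` integrate the flow-matching ODE with fixed Euler steps under classifier-free guidance: the estimator runs on a conditional half and a zeroed-condition half of a doubled batch, the two predicted velocities are combined as (1 + r) * v_cond - r * v_uncond with r = inference_cfg_rate, and the timesteps follow the cosine-warped grid t <- 1 - cos(pi * t / 2). A self-contained toy version of that update rule; `toy_estimator` is an illustrative stand-in for the DiT estimator, not part of the package:

import torch

def toy_estimator(x, t, conditioned):
    # Stand-in velocity field: pull x toward 1 when conditioned, toward 0 otherwise.
    target = torch.ones_like(x) if conditioned else torch.zeros_like(x)
    return target - x

def solve_euler_cfg(x, n_timesteps=10, cfg_rate=0.7):
    # Cosine-warped time grid, as in CausalConditionalCFM.forward / forward_chunk.
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        v_cond = toy_estimator(x, t, True)     # conditional half of the doubled batch
        v_uncond = toy_estimator(x, t, False)  # zeroed-condition half
        v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond  # CFG combination
        x = x + dt * v
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

x = solve_euler_cfg(torch.randn(1, 4))
# The combined field (1 + r)*(1 - x) - r*(0 - x) = (1 + r) - x has fixed point
# 1 + r = 1.7 rather than 1.0, showing how guidance over-emphasizes the
# conditional direction relative to the unguided target.

The chunked variant additionally carries per-timestep cnn/att caches through `self.cnn_cache_buffer` and `self.att_cache_buffer`, so each new chunk attends over the context accumulated from everything generated so far.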