minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/dataset/processor.py
@@ -0,0 +1,434 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+
+ import pyarrow.parquet as pq
+ from io import BytesIO
+ import torch
+ import torchaudio
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.nn.functional as F
+ import pyworld as pw
+
+
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
+
+
+ def parquet_opener(data, mode='train', tts_data=None):
+     """ Given urls or local parquet files, yield one sample dict per row.
+     Inplace operation.
+
+     Args:
+         data(Iterable[{src}]): dicts carrying a 'src' url or local file path
+         tts_data: mapping from utt id to tts texts, used when mode != 'train'
+
+     Returns:
+         Iterable[{src, ...row columns}]
+     """
+     # NOTE avoid the mutable-default-argument pitfall
+     tts_data = tts_data if tts_data is not None else {}
+     for sample in data:
+         assert 'src' in sample
+         url = sample['src']
+         try:
+             for df in pq.ParquetFile(url).iter_batches(batch_size=64):
+                 df = df.to_pandas()
+                 for i in range(len(df)):
+                     sample.update(dict(df.loc[i]))
+                     if mode == 'train':
+                         # NOTE do not yield sample directly, must initialize a new dict
+                         yield {**sample}
+                     else:
+                         for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                             yield {**sample, 'tts_index': index, 'tts_text': text}
+         except Exception as ex:
+             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+
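The row schema is implicit: the processors below expect columns such as 'utt', 'audio_data', 'text', 'speech_token', 'utt_embedding' and 'spk_embedding'. A minimal sketch of a compatible parquet shard and of driving the opener; the embedding width of 192, the audio bytes and the file name are placeholder assumptions, not part of the package:

import pyarrow as pa
import pyarrow.parquet as pq

rows = {
    'utt': ['utt_0001'],                 # utterance id
    'audio_data': [b'...'],              # raw audio bytes, decoded later by filter()
    'text': ['hello world'],             # transcript, encoded later by tokenize()
    'speech_token': [[12, 57, 98]],      # pre-extracted speech tokens
    'utt_embedding': [[0.1] * 192],      # utterance-level embedding (width assumed)
    'spk_embedding': [[0.1] * 192],      # speaker-level embedding (width assumed)
}
pq.write_table(pa.table(rows), 'demo.parquet')

# the opener consumes {'src': path} dicts and yields one enriched dict per row
for sample in parquet_opener([{'src': 'demo.parquet'}]):
    print(sample['utt'], sample['text'])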
+ def filter(data,
+            max_length=10240,
+            min_length=10,
+            token_max_length=200,
+            token_min_length=1,
+            min_output_input_ratio=0.0005,
+            max_output_input_ratio=1,
+            mode='train'):
+     """ Filter samples according to feature and label length.
+     Inplace operation.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         max_length: drop utterances longer than max_length (in 10 ms frames)
+         min_length: drop utterances shorter than min_length (in 10 ms frames)
+         token_max_length: drop utterances with more than token_max_length
+             text tokens, especially when using char units for English modeling
+         token_min_length: drop utterances with fewer than token_min_length
+             text tokens
+         min_output_input_ratio: minimum ratio of
+             token_length / feats_length (in 10 ms frames)
+         max_output_input_ratio: maximum ratio of
+             token_length / feats_length (in 10 ms frames)
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+         sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
+         del sample['audio_data']
+         # sample['speech'] is a torch.Tensor; at 10 ms resolution there are 100 frames per second
+         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
+         if num_frames < min_length:
+             continue
+         if num_frames > max_length:
+             continue
+         if len(sample['text_token']) < token_min_length:
+             continue
+         if len(sample['text_token']) > token_max_length:
+             continue
+         if len(sample['speech_token']) == 0:
+             continue
+         if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+             continue
+         if num_frames != 0:
+             if len(sample['text_token']) / num_frames < min_output_input_ratio:
+                 continue
+             if len(sample['text_token']) / num_frames > max_output_input_ratio:
+                 continue
+         yield sample
+
+
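Since num_frames counts 10 ms units, the defaults bound utterance duration between 0.1 s (min_length=10) and about 102 s (max_length=10240). A quick sanity check of that arithmetic:

import torch

sr = 16000
speech = torch.zeros(1, sr * 5)            # 5 seconds of mono audio
num_frames = speech.size(1) / sr * 100     # 500 frames of 10 ms each
assert 10 <= num_frames <= 10240           # passes the default length gates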
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
+     """ Resample data.
+     Inplace operation.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         resample_rate: target sample rate
+         min_sample_rate: drop utterances whose source rate is below this
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         sample_rate = sample['sample_rate']
+         waveform = sample['speech']
+         if sample_rate != resample_rate:
+             if sample_rate < min_sample_rate:
+                 continue
+             sample['sample_rate'] = resample_rate
+             sample['speech'] = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+         # rescale so the waveform never exceeds [-1, 1]
+         max_val = sample['speech'].abs().max()
+         if max_val > 1:
+             sample['speech'] /= max_val
+         yield sample
+
+
+ def truncate(data, truncate_length=24576, mode='train'):
+     """ Truncate data to a fixed length, zero-padding short utterances.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         truncate_length: truncate length (in samples)
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         waveform = sample['speech']
+         if waveform.shape[1] > truncate_length:
+             # random crop down to truncate_length
+             start = random.randint(0, waveform.shape[1] - truncate_length)
+             waveform = waveform[:, start: start + truncate_length]
+         else:
+             # zero-pad up to truncate_length
+             waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+         sample['speech'] = waveform
+         yield sample
+
+
+ def compute_fbank(data,
+                   feat_extractor,
+                   token_mel_ratio=0,
+                   mode='train'):
+     """ Extract fbank features.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         feat_extractor: callable mapping a waveform to a mel spectrogram
+         token_mel_ratio: mel frames per speech token; 0 disables alignment trimming
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+         if token_mel_ratio != 0:
+             # trim to align speech_token and speech_feat
+             token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
+             feat = feat[:token_mel_ratio * token_len]
+             sample["speech_token"] = sample["speech_token"][:token_len]
+         sample['speech_feat'] = feat
+         yield sample
+
+
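The token_mel_ratio trim keeps speech_token and speech_feat in exact lockstep. A sketch assuming a ratio of 2 mel frames per speech token (the actual ratio comes from the model config, not from this module):

import torch

feat = torch.randn(101, 80)          # 101 mel frames, 80 bins
speech_token = torch.arange(49)      # 49 speech tokens
token_mel_ratio = 2                  # assumed, e.g. 50 Hz mel vs 25 Hz tokens
token_len = int(min(feat.shape[0] / token_mel_ratio, speech_token.shape[0]))
feat = feat[:token_mel_ratio * token_len]       # trimmed to 98 frames
speech_token = speech_token[:token_len]         # trimmed to 49 tokens
assert feat.shape[0] == token_mel_ratio * speech_token.shape[0]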
+ def compute_f0(data, sample_rate, hop_size, mode='train'):
+     """ Extract f0 with pyworld.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     frame_period = hop_size * 1000 / sample_rate
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
+         if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
+             _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
+         f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
+         # interpolate f0 to the mel frame count so pitch_feat aligns with speech_feat
+         f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
+         sample['pitch_feat'] = f0
+         yield sample
+
+
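The same harvest, dio-fallback, stonemask chain can be run standalone; a sketch on a synthetic 220 Hz tone, with hop_size=256 assumed purely for illustration:

import numpy as np
import pyworld as pw

sr, hop_size = 22050, 256
t_axis = np.arange(sr) / sr                           # 1 second of audio
x = (0.5 * np.sin(2 * np.pi * 220 * t_axis)).astype('double')
frame_period = hop_size * 1000 / sr                   # ms per analysis frame
_f0, t = pw.harvest(x, sr, frame_period=frame_period)
if sum(_f0 != 0) < 5:                                 # fallback, as above
    _f0, t = pw.dio(x, sr, frame_period=frame_period)
f0 = pw.stonemask(x, _f0, t, sr)                      # refined f0 track
print(f0[5:10])                                       # values near 220 Hz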
+ def parse_embedding(data, normalize, mode='train'):
+     """ Parse utt_embedding/spk_embedding.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         normalize: whether to L2-normalize the embeddings
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
+         if normalize:
+             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
+             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
+         yield sample
+
+
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
+     """ Encode text into char or BPE tokens.
+     Inplace operation.
+
+     Args:
+         data: Iterable[{key, wav, txt, sample_rate}]
+
+     Returns:
+         Iterable[{key, wav, txt, tokens, label, sample_rate}]
+     """
+     tokenizer = get_tokenizer()
+     for sample in data:
+         assert 'text' in sample
+         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+         yield sample
+
+
+ def shuffle(data, shuffle_size=10000, mode='train'):
+     """ Locally shuffle the data.
+
+     Args:
+         data: Iterable[{key, feat, label}]
+         shuffle_size: buffer size for shuffle
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= shuffle_size:
+             random.shuffle(buf)
+             for x in buf:
+                 yield x
+             buf = []
+     # The samples left over
+     random.shuffle(buf)
+     for x in buf:
+         yield x
+
+
+ def sort(data, sort_size=500, mode='train'):
+     """ Sort the data by feature length.
+     Sort is used after shuffle and before batch, so we can group
+     utts with similar lengths into a batch, and `sort_size` should
+     be less than `shuffle_size`.
+
+     Args:
+         data: Iterable[{key, feat, label}]
+         sort_size: buffer size for sort
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= sort_size:
+             buf.sort(key=lambda x: x['speech_feat'].size(0))
+             for x in buf:
+                 yield x
+             buf = []
+     # The samples left over
+     buf.sort(key=lambda x: x['speech_feat'].size(0))
+     for x in buf:
+         yield x
+
+
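shuffle and sort work together: shuffle randomizes a large window of the stream, then sort re-orders smaller windows by length so that batching pads efficiently while the stream stays globally shuffled, which is why sort_size should stay below shuffle_size. A toy run over fake feature lengths:

import torch

stream = ({'speech_feat': torch.zeros(n, 80)} for n in [30, 5, 90, 12, 64, 7])
out = sort(shuffle(stream, shuffle_size=4), sort_size=3)
print([x['speech_feat'].size(0) for x in out])
# each window of up to 3 consecutive samples comes out length-sorted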
+ def static_batch(data, batch_size=16):
+     """ Batch the data into fixed-size batches of `batch_size`.
+
+     Args:
+         data: Iterable[{key, feat, label}]
+         batch_size: batch size
+
+     Returns:
+         Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= batch_size:
+             yield buf
+             buf = []
+     if len(buf) > 0:
+         yield buf
+
+
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
+     """ Dynamically batch the data until the total padded frames in the
+     batch would exceed `max_frames_in_batch`.
+
+     Args:
+         data: Iterable[{key, feat, label}]
+         max_frames_in_batch: max frames in one batch
+
+     Returns:
+         Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     longest_frames = 0
+     for sample in data:
+         assert 'speech_feat' in sample
+         assert isinstance(sample['speech_feat'], torch.Tensor)
+         new_sample_frames = sample['speech_feat'].size(0)
+         longest_frames = max(longest_frames, new_sample_frames)
+         # cost of the batch once every sample is padded to the longest one
+         frames_after_padding = longest_frames * (len(buf) + 1)
+         if frames_after_padding > max_frames_in_batch:
+             yield buf
+             buf = [sample]
+             longest_frames = new_sample_frames
+         else:
+             buf.append(sample)
+     if len(buf) > 0:
+         yield buf
+
+
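The bound is on padded frames, not raw frames: each candidate batch is costed as longest * count. For example, with max_frames_in_batch=200 and lengths 80, 70, 90, the third sample would push the padded cost to 90 * 3 = 270, so [80, 70] is emitted first:

import torch

stream = ({'speech_feat': torch.zeros(n, 80)} for n in [80, 70, 90])
batches = list(dynamic_batch(stream, max_frames_in_batch=200))
print([[s['speech_feat'].size(0) for s in b] for b in batches])  # [[80, 70], [90]]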
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
+     """ Wrapper for static/dynamic batch.
+     """
+     if batch_type == 'static':
+         return static_batch(data, batch_size)
+     elif batch_type == 'dynamic':
+         return dynamic_batch(data, max_frames_in_batch)
+     else:
+         logging.fatal('Unsupported batch type {}'.format(batch_type))
+
+
+ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
+     """ Pad a list of samples into a training batch.
+
+     Args:
+         data: Iterable[List[{key, feat, label}]]
+
+     Returns:
+         Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
+     """
+     for sample in data:
+         assert isinstance(sample, list)
+         # sort the batch by feature length (dim 0 of speech_feat), longest first
+         speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],
+                                        dtype=torch.int32)
+         order = torch.argsort(speech_feat_len, descending=True)
+
+         utts = [sample[i]['utt'] for i in order]
+         speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+         speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+         speech = pad_sequence(speech, batch_first=True, padding_value=0)
+         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
+         speech_token = pad_sequence(speech_token,
+                                     batch_first=True,
+                                     padding_value=0)
+         speech_feat = [sample[i]['speech_feat'] for i in order]
+         speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
+         speech_feat = pad_sequence(speech_feat,
+                                    batch_first=True,
+                                    padding_value=0)
+         text = [sample[i]['text'] for i in order]
+         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
+         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
+         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
+         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+         batch = {
+             "utts": utts,
+             "speech": speech,
+             "speech_len": speech_len,
+             "speech_token": speech_token,
+             "speech_token_len": speech_token_len,
+             "speech_feat": speech_feat,
+             "speech_feat_len": speech_feat_len,
+             "text": text,
+             "text_token": text_token,
+             "text_token_len": text_token_len,
+             "utt_embedding": utt_embedding,
+             "spk_embedding": spk_embedding,
+         }
+         if gan is True:
+             # in gan training we need pitch_feat
+             pitch_feat = [sample[i]['pitch_feat'] for i in order]
+             pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+             pitch_feat = pad_sequence(pitch_feat,
+                                       batch_first=True,
+                                       padding_value=0)
+             batch["pitch_feat"] = pitch_feat
+             batch["pitch_feat_len"] = pitch_feat_len
+         else:
+             # only gan training needs speech, delete it to save memory
+             del batch["speech"]
+             del batch["speech_len"]
+         if dpo is True:
+             reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+             reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+             reject_speech_token = pad_sequence(reject_speech_token,
+                                                batch_first=True,
+                                                padding_value=0)
+             batch['reject_speech_token'] = reject_speech_token
+             batch['reject_speech_token_len'] = reject_speech_token_len
+         if use_spk_embedding is True:
+             batch["embedding"] = batch["spk_embedding"]
+         else:
+             batch["embedding"] = batch["utt_embedding"]
+         yield batch
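These processors are designed to be chained, each wrapping the previous iterator; in the package the chain is assembled by cosyvoice/dataset/dataset.py from a config, but the data flow is equivalent to the hand-wired sketch below. my_tokenizer_factory, my_feat_extractor, the shard name and the token_mel_ratio value are placeholders, not values shipped with this module:

data = parquet_opener([{'src': 'shard-00000.parquet'}])
data = tokenize(data, get_tokenizer=my_tokenizer_factory, allowed_special='all')
data = filter(data, max_length=10240, min_length=10)
data = resample(data, resample_rate=24000)
data = compute_fbank(data, feat_extractor=my_feat_extractor, token_mel_ratio=2)
data = parse_embedding(data, normalize=True)
data = shuffle(data, shuffle_size=1000)
data = sort(data, sort_size=500)
data = batch(data, batch_type='dynamic', max_frames_in_batch=12000)
data = padding(data, use_spk_embedding=False)

for train_batch in data:
    print(train_batch['speech_feat'].shape)   # (B, T_max, num_mels)
    break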