xinference-0.13.2-py3-none-any.whl → xinference-0.13.4-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.
Files changed (103)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +30 -5
  4. xinference/client/restful/restful_client.py +18 -3
  5. xinference/constants.py +0 -4
  6. xinference/core/chat_interface.py +2 -2
  7. xinference/core/image_interface.py +6 -3
  8. xinference/core/model.py +9 -4
  9. xinference/core/scheduler.py +4 -4
  10. xinference/core/supervisor.py +2 -0
  11. xinference/core/worker.py +7 -0
  12. xinference/deploy/utils.py +6 -0
  13. xinference/model/audio/core.py +9 -4
  14. xinference/model/audio/cosyvoice.py +136 -0
  15. xinference/model/audio/model_spec.json +24 -0
  16. xinference/model/audio/model_spec_modelscope.json +27 -0
  17. xinference/model/core.py +25 -4
  18. xinference/model/embedding/core.py +88 -13
  19. xinference/model/embedding/model_spec.json +8 -0
  20. xinference/model/embedding/model_spec_modelscope.json +8 -0
  21. xinference/model/flexible/core.py +8 -2
  22. xinference/model/flexible/launchers/__init__.py +1 -0
  23. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  24. xinference/model/image/core.py +8 -5
  25. xinference/model/image/model_spec.json +36 -5
  26. xinference/model/image/model_spec_modelscope.json +21 -3
  27. xinference/model/image/stable_diffusion/core.py +36 -28
  28. xinference/model/llm/core.py +6 -4
  29. xinference/model/llm/ggml/llamacpp.py +7 -5
  30. xinference/model/llm/llm_family.json +802 -82
  31. xinference/model/llm/llm_family.py +6 -6
  32. xinference/model/llm/llm_family_csghub.json +39 -0
  33. xinference/model/llm/llm_family_modelscope.json +295 -47
  34. xinference/model/llm/mlx/core.py +7 -0
  35. xinference/model/llm/pytorch/chatglm.py +246 -5
  36. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  37. xinference/model/llm/pytorch/deepseek_vl.py +2 -1
  38. xinference/model/llm/pytorch/falcon.py +2 -1
  39. xinference/model/llm/pytorch/llama_2.py +4 -2
  40. xinference/model/llm/pytorch/omnilmm.py +2 -1
  41. xinference/model/llm/pytorch/qwen_vl.py +2 -1
  42. xinference/model/llm/pytorch/vicuna.py +2 -1
  43. xinference/model/llm/pytorch/yi_vl.py +2 -1
  44. xinference/model/llm/sglang/core.py +12 -6
  45. xinference/model/llm/utils.py +78 -1
  46. xinference/model/llm/vllm/core.py +9 -5
  47. xinference/model/rerank/core.py +4 -3
  48. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  50. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  51. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  52. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  53. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  54. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  55. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  56. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  58. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  59. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  61. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  63. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  64. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  65. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  66. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  67. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  68. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  69. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  70. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  71. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  72. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  73. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  74. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  77. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  78. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  79. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  80. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  81. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  82. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  83. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  84. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  85. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  86. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  87. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  88. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  89. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  90. xinference/web/ui/build/asset-manifest.json +3 -3
  91. xinference/web/ui/build/index.html +1 -1
  92. xinference/web/ui/build/static/js/{main.95c1d652.js → main.af906659.js} +3 -3
  93. xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
  94. xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
  95. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/METADATA +39 -11
  96. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/RECORD +101 -57
  97. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
  100. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
  101. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
  102. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0
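
The bulk of this release is new CosyVoice text-to-speech support: a vendored copy of the CosyVoice codebase under xinference/thirdparty/cosyvoice/ (items 48–89) plus new audio model specs and loading hooks (items 13–16). For orientation, here is a hypothetical sketch of driving the new model through the RESTful client; the model name and the speech() handle are assumptions inferred from the changed files (restful_api.py, restful_client.py, model/audio/*), not confirmed against the 0.13.4 API.

# Hypothetical sketch, assuming a local xinference 0.13.4 server and that the
# CosyVoice entries added to model/audio/model_spec.json register under this name.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
model = client.get_model(uid)
# speech() is the text-to-speech entry point on audio model handles; the
# voice value below is illustrative, not taken from this diff.
audio = model.speech("Hello from xinference 0.13.4.", voice="中文女")
with open("hello.mp3", "wb") as f:
    f.write(audio)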
xinference/thirdparty/cosyvoice/dataset/dataset.py (item 57, new file)
@@ -0,0 +1,160 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ import json
+ import math
+ from functools import partial
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import IterableDataset
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
+
+
+ class Processor(IterableDataset):
+
+     def __init__(self, source, f, *args, **kw):
+         assert callable(f)
+         self.source = source
+         self.f = f
+         self.args = args
+         self.kw = kw
+
+     def set_epoch(self, epoch):
+         self.source.set_epoch(epoch)
+
+     def __iter__(self):
+         """ Return an iterator over the source dataset processed by the
+             given processor.
+         """
+         assert self.source is not None
+         assert callable(self.f)
+         return self.f(iter(self.source), *self.args, **self.kw)
+
+     def apply(self, f):
+         assert callable(f)
+         return Processor(self, f, *self.args, **self.kw)
+
+
+ class DistributedSampler:
+
+     def __init__(self, shuffle=True, partition=True):
+         self.epoch = -1
+         self.update()
+         self.shuffle = shuffle
+         self.partition = partition
+
+     def update(self):
+         assert dist.is_available()
+         if dist.is_initialized():
+             self.rank = dist.get_rank()
+             self.world_size = dist.get_world_size()
+         else:
+             self.rank = 0
+             self.world_size = 1
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is None:
+             self.worker_id = 0
+             self.num_workers = 1
+         else:
+             self.worker_id = worker_info.id
+             self.num_workers = worker_info.num_workers
+         return dict(rank=self.rank,
+                     world_size=self.world_size,
+                     worker_id=self.worker_id,
+                     num_workers=self.num_workers)
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     def sample(self, data):
+         """ Sample data according to rank/world_size/num_workers
+
+             Args:
+                 data(List): input data list
+
+             Returns:
+                 List: data list after sample
+         """
+         data = list(range(len(data)))
+         # force datalist even
+         if self.partition:
+             if self.shuffle:
+                 random.Random(self.epoch).shuffle(data)
+             if len(data) < self.world_size:
+                 data = data * math.ceil(self.world_size / len(data))
+                 data = data[:self.world_size]
+             data = data[self.rank::self.world_size]
+         if len(data) < self.num_workers:
+             data = data * math.ceil(self.num_workers / len(data))
+             data = data[:self.num_workers]
+         data = data[self.worker_id::self.num_workers]
+         return data
+
+
+ class DataList(IterableDataset):
+
+     def __init__(self, lists, shuffle=True, partition=True):
+         self.lists = lists
+         self.sampler = DistributedSampler(shuffle, partition)
+
+     def set_epoch(self, epoch):
+         self.sampler.set_epoch(epoch)
+
+     def __iter__(self):
+         sampler_info = self.sampler.update()
+         indexes = self.sampler.sample(self.lists)
+         for index in indexes:
+             data = dict(src=self.lists[index])
+             data.update(sampler_info)
+             yield data
+
+
+ def Dataset(data_list_file,
+             data_pipeline,
+             mode='train',
+             shuffle=True,
+             partition=True,
+             tts_file='',
+             prompt_utt2data=''):
+     """ Construct dataset from arguments
+
+         We have two shuffle stage in the Dataset. The first is global
+         shuffle at shards tar/raw file level. The second is global shuffle
+         at training samples level.
+
+         Args:
+             data_type(str): raw/shard
+             tokenizer (BaseTokenizer): tokenizer to tokenize
+             partition(bool): whether to do data partition in terms of rank
+     """
+     assert mode in ['train', 'inference']
+     lists = read_lists(data_list_file)
+     if mode == 'inference':
+         with open(tts_file) as f:
+             tts_data = json.load(f)
+         utt2lists = read_json_lists(prompt_utt2data)
+         # filter unnecessary file in inference mode
+         lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+     dataset = DataList(lists,
+                        shuffle=shuffle,
+                        partition=partition)
+     if mode == 'inference':
+         # map partial arg tts_data in inference mode
+         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+     for func in data_pipeline:
+         dataset = Processor(dataset, func, mode=mode)
+     return dataset
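
The file above is a small generator-pipeline framework: Dataset() wraps a DataList source in one Processor per pipeline stage, and DistributedSampler shards the source list twice, first across ranks, then across dataloader workers. A minimal standalone sketch of the same pattern follows; the Processor here mirrors the class in the hunk, while the toy stages are illustrative and not part of the diff.

import random
import torch
from torch.utils.data import IterableDataset

class Processor(IterableDataset):
    # mirrors the Processor added above: one generator stage per wrapper
    def __init__(self, source, f, *args, **kw):
        self.source, self.f, self.args, self.kw = source, f, args, kw
    def __iter__(self):
        return self.f(iter(self.source), *self.args, **self.kw)

def to_sample(data, mode='train'):
    # stand-in for parquet_opener: map a shard path to sample dicts
    for src in data:
        yield {'src': src, 'speech_feat': torch.randn(random.randint(5, 20), 80)}

def drop_short(data, min_frames=8, mode='train'):
    # stand-in for filter(): length-based filtering
    for sample in data:
        if sample['speech_feat'].size(0) >= min_frames:
            yield sample

dataset = ['shard-0.parquet', 'shard-1.parquet', 'shard-2.parquet']
for func in [to_sample, drop_short]:  # same wrapping loop as Dataset()
    dataset = Processor(dataset, func, mode='train')
print(sum(1 for _ in dataset))  # number of samples surviving the filter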
xinference/thirdparty/cosyvoice/dataset/processor.py (item 58, new file)
@@ -0,0 +1,369 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+
+ import pyarrow.parquet as pq
+ from io import BytesIO
+ import torch
+ import torchaudio
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.nn.functional as F
+
+ torchaudio.set_audio_backend('soundfile')
+
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+
+
+ def parquet_opener(data, mode='train', tts_data={}):
+     """ Give url or local file, return file descriptor
+         Inplace operation.
+
+         Args:
+             data(Iterable[str]): url or local file list
+
+         Returns:
+             Iterable[{src, stream}]
+     """
+     for sample in data:
+         assert 'src' in sample
+         url = sample['src']
+         try:
+             df = pq.read_table(url).to_pandas()
+             for i in range(len(df)):
+                 if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                     continue
+                 sample.update(dict(df.loc[i]))
+                 if mode == 'train':
+                     # NOTE do not return sample directly, must initialize a new dict
+                     yield {**sample}
+                 else:
+                     for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                         yield {**sample, 'tts_index': index, 'tts_text': text}
+         except Exception as ex:
+             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+ def filter(data,
+            max_length=10240,
+            min_length=10,
+            token_max_length=200,
+            token_min_length=1,
+            min_output_input_ratio=0.0005,
+            max_output_input_ratio=1,
+            mode='train'):
+     """ Filter sample according to feature and label length
+         Inplace operation.
+
+         Args::
+             data: Iterable[{key, wav, label, sample_rate}]
+             max_length: drop utterance which is greater than max_length(10ms)
+             min_length: drop utterance which is less than min_length(10ms)
+             token_max_length: drop utterance which is greater than
+                 token_max_length, especially when use char unit for
+                 english modeling
+             token_min_length: drop utterance which is
+                 less than token_max_length
+             min_output_input_ratio: minimal ration of
+                 token_length / feats_length(10ms)
+             max_output_input_ratio: maximum ration of
+                 token_length / feats_length(10ms)
+
+         Returns:
+             Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+         del sample['audio_data']
+         # sample['wav'] is torch.Tensor, we have 100 frames every second
+         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
+         if num_frames < min_length:
+             continue
+         if num_frames > max_length:
+             continue
+         if len(sample['text_token']) < token_min_length:
+             continue
+         if len(sample['text_token']) > token_max_length:
+             continue
+         if len(sample['speech_token']) == 0:
+             continue
+         if num_frames != 0:
+             if len(sample['text_token']) / num_frames < min_output_input_ratio:
+                 continue
+             if len(sample['text_token']) / num_frames > max_output_input_ratio:
+                 continue
+         yield sample
+
+
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
+     """ Resample data.
+         Inplace operation.
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+             resample_rate: target resample rate
+
+         Returns:
+             Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         sample_rate = sample['sample_rate']
+         waveform = sample['speech']
+         if sample_rate != resample_rate:
+             if sample_rate < min_sample_rate:
+                 continue
+             sample['sample_rate'] = resample_rate
+             sample['speech'] = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+         max_val = sample['speech'].abs().max()
+         if max_val > 1:
+             sample['speech'] /= max_val
+         yield sample
+
+
+ def compute_fbank(data,
+                   feat_extractor,
+                   mode='train'):
+     """ Extract fbank
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+         sample['speech_feat'] = mat
+         del sample['speech']
+         yield sample
+
+
+ def parse_embedding(data, normalize, mode='train'):
+     """ Parse utt_embedding/spk_embedding
+
+         Args:
+             data: Iterable[{key, wav, label, sample_rate}]
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
+         if normalize:
+             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
+             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
+         yield sample
+
+
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
+     """ Decode text to chars or BPE
+         Inplace operation
+
+         Args:
+             data: Iterable[{key, wav, txt, sample_rate}]
+
+         Returns:
+             Iterable[{key, wav, txt, tokens, label, sample_rate}]
+     """
+     tokenizer = get_tokenizer()
+     for sample in data:
+         assert 'text' in sample
+         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+         if mode == 'inference':
+             sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
+         yield sample
+
+
+ def shuffle(data, shuffle_size=10000, mode='train'):
+     """ Local shuffle the data
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             shuffle_size: buffer size for shuffle
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= shuffle_size:
+             random.shuffle(buf)
+             for x in buf:
+                 yield x
+             buf = []
+     # The sample left over
+     random.shuffle(buf)
+     for x in buf:
+         yield x
+
+
+ def sort(data, sort_size=500, mode='train'):
+     """ Sort the data by feature length.
+         Sort is used after shuffle and before batch, so we can group
+         utts with similar lengths into a batch, and `sort_size` should
+         be less than `shuffle_size`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             sort_size: buffer size for sort
+
+         Returns:
+             Iterable[{key, feat, label}]
+     """
+
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= sort_size:
+             buf.sort(key=lambda x: x['speech_feat'].size(0))
+             for x in buf:
+                 yield x
+             buf = []
+     # The sample left over
+     buf.sort(key=lambda x: x['speech_feat'].size(0))
+     for x in buf:
+         yield x
+
+
+ def static_batch(data, batch_size=16):
+     """ Static batch the data by `batch_size`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             batch_size: batch size
+
+         Returns:
+             Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     for sample in data:
+         buf.append(sample)
+         if len(buf) >= batch_size:
+             yield buf
+             buf = []
+     if len(buf) > 0:
+         yield buf
+
+
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
+     """ Dynamic batch the data until the total frames in batch
+         reach `max_frames_in_batch`
+
+         Args:
+             data: Iterable[{key, feat, label}]
+             max_frames_in_batch: max_frames in one batch
+
+         Returns:
+             Iterable[List[{key, feat, label}]]
+     """
+     buf = []
+     longest_frames = 0
+     for sample in data:
+         assert 'speech_feat' in sample
+         assert isinstance(sample['speech_feat'], torch.Tensor)
+         new_sample_frames = sample['speech_feat'].size(0)
+         longest_frames = max(longest_frames, new_sample_frames)
+         frames_after_padding = longest_frames * (len(buf) + 1)
+         if frames_after_padding > max_frames_in_batch:
+             yield buf
+             buf = [sample]
+             longest_frames = new_sample_frames
+         else:
+             buf.append(sample)
+     if len(buf) > 0:
+         yield buf
+
+
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
+     """ Wrapper for static/dynamic batch
+     """
+     if mode == 'inference':
+         return static_batch(data, 1)
+     else:
+         if batch_type == 'static':
+             return static_batch(data, batch_size)
+         elif batch_type == 'dynamic':
+             return dynamic_batch(data, max_frames_in_batch)
+         else:
+             logging.fatal('Unsupported batch type {}'.format(batch_type))
+
+
+ def padding(data, use_spk_embedding, mode='train'):
+     """ Padding the data into training data
+
+         Args:
+             data: Iterable[List[{key, feat, label}]]
+
+         Returns:
+             Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
+     """
+     for sample in data:
+         assert isinstance(sample, list)
+         speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
+                                        dtype=torch.int32)
+         order = torch.argsort(speech_feat_len, descending=True)
+
+         utts = [sample[i]['utt'] for i in order]
+         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
+         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
+         speech_token = pad_sequence(speech_token,
+                                     batch_first=True,
+                                     padding_value=0)
+         speech_feat = [sample[i]['speech_feat'] for i in order]
+         speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
+         speech_feat = pad_sequence(speech_feat,
+                                    batch_first=True,
+                                    padding_value=0)
+         text = [sample[i]['text'] for i in order]
+         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
+         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
+         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
+         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
+         batch = {
+             "utts": utts,
+             "speech_token": speech_token,
+             "speech_token_len": speech_token_len,
+             "speech_feat": speech_feat,
+             "speech_feat_len": speech_feat_len,
+             "text": text,
+             "text_token": text_token,
+             "text_token_len": text_token_len,
+             "utt_embedding": utt_embedding,
+             "spk_embedding": spk_embedding,
+         }
+         if mode == 'inference':
+             tts_text = [sample[i]['tts_text'] for i in order]
+             tts_index = [sample[i]['tts_index'] for i in order]
+             tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
+             tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
+             tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
+             batch.update({'tts_text': tts_text,
+                           'tts_index': tts_index,
+                           'tts_text_token': tts_text_token,
+                           'tts_text_token_len': tts_text_token_len})
+         if use_spk_embedding is True:
+             batch["embedding"] = batch["spk_embedding"]
+         else:
+             batch["embedding"] = batch["utt_embedding"]
+         yield batch
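
Among the stages above, dynamic_batch is the subtle one: it caps the padded size of a batch (longest sample times batch count), not the raw frame total, so a single long utterance can flush the buffer early. A self-contained sketch follows; the generator is restated here rather than imported, since the vendored module's cosyvoice.* imports expect xinference's adjusted sys.path, and the sample lengths are made up.

import torch

def dynamic_batch(data, max_frames_in_batch=12000):
    # restatement of the generator above, minus the asserts
    buf, longest = [], 0
    for sample in data:
        frames = sample['speech_feat'].size(0)
        longest = max(longest, frames)
        if longest * (len(buf) + 1) > max_frames_in_batch:
            yield buf
            buf, longest = [sample], frames
        else:
            buf.append(sample)
    if buf:
        yield buf

samples = [{'speech_feat': torch.zeros(n, 80)} for n in (100, 120, 400, 90, 90)]
for b in dynamic_batch(samples, max_frames_in_batch=600):
    lens = [s['speech_feat'].size(0) for s in b]
    # padded cost = max(lens) * len(lens), always kept <= 600 here
    print(lens, max(lens) * len(lens))

With a 600-frame budget this yields [100, 120] (cost 240), then [400] (cost 400), then [90, 90] (cost 180): the 400-frame utterance would have pushed the padded cost of the first buffer past the cap, so it starts a batch of its own.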