xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of xinference has been flagged as a potentially problematic release.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +26 -4
- xinference/client/restful/restful_client.py +16 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/model_spec.json +7 -0
- xinference/model/image/stable_diffusion/core.py +6 -1
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +8 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/utils/file_utils.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+#               2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import torchaudio
+
+
+def read_lists(list_file):
+    lists = []
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            lists.append(line.strip())
+    return lists
+
+def read_json_lists(list_file):
+    lists = read_lists(list_file)
+    results = {}
+    for fn in lists:
+        with open(fn, 'r', encoding='utf8') as fin:
+            results.update(json.load(fin))
+    return results
+
+def load_wav(wav, target_sr):
+    speech, sample_rate = torchaudio.load(wav)
+    speech = speech.mean(dim=0, keepdim=True)
+    if sample_rate != target_sr:
+        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+    return speech
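To make the helpers above concrete, here is a small usage sketch (not part of the diff). The wav path and the 16000 Hz target rate are illustrative values; the import path simply mirrors the vendored location listed above.

from xinference.thirdparty.cosyvoice.utils.file_utils import load_wav

# Illustrative call: returns a channel-averaged (mono) tensor of shape
# (1, num_samples). Note the assert above: the source file's sample rate
# must be strictly greater than target_sr whenever resampling is needed.
speech = load_wav('prompt.wav', 16000)
print(speech.shape)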
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/utils/frontend_utils.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
+
+# whether contain chinese character
+def contains_chinese(text):
+    return bool(chinese_char_pattern.search(text))
+
+
+# replace special symbol
+def replace_corner_mark(text):
+    text = text.replace('²', '平方')
+    text = text.replace('³', '立方')
+    return text
+
+
+# remove meaningless symbol
+def remove_bracket(text):
+    text = text.replace('(', '').replace(')', '')
+    text = text.replace('【', '').replace('】', '')
+    text = text.replace('`', '').replace('`', '')
+    text = text.replace("——", " ")
+    return text
+
+
+# spell Arabic numerals
+def spell_out_number(text: str, inflect_parser):
+    new_text = []
+    st = None
+    for i, c in enumerate(text):
+        if not c.isdigit():
+            if st is not None:
+                num_str = inflect_parser.number_to_words(text[st: i])
+                new_text.append(num_str)
+                st = None
+            new_text.append(c)
+        else:
+            if st is None:
+                st = i
+    if st is not None and st < len(text):
+        num_str = inflect_parser.number_to_words(text[st:])
+        new_text.append(num_str)
+    return ''.join(new_text)
+
+
+# split paragrah logic:
+# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
+# 2. cal sentence len according to lang
+# 3. split sentence according to puncatation
+def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
+    def calc_utt_length(_text: str):
+        if lang == "zh":
+            return len(_text)
+        else:
+            return len(tokenize(_text))
+
+    def should_merge(_text: str):
+        if lang == "zh":
+            return len(_text) < merge_len
+        else:
+            return len(tokenize(_text)) < merge_len
+
+    if lang == "zh":
+        pounc = ['。', '？', '！', '；', '：', '、', '.', '?', '!', ';']
+    else:
+        pounc = ['.', '?', '!', ';', ':']
+    if comma_split:
+        pounc.extend(['，', ','])
+    st = 0
+    utts = []
+    for i, c in enumerate(text):
+        if c in pounc:
+            if len(text[st: i]) > 0:
+                utts.append(text[st: i] + c)
+            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
+                tmp = utts.pop(-1)
+                utts.append(tmp + text[i + 1])
+                st = i + 2
+            else:
+                st = i + 1
+    if len(utts) == 0:
+        if lang == "zh":
+            utts.append(text + '。')
+        else:
+            utts.append(text + '.')
+    final_utts = []
+    cur_utt = ""
+    for utt in utts:
+        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
+            final_utts.append(cur_utt)
+            cur_utt = ""
+        cur_utt = cur_utt + utt
+    if len(cur_utt) > 0:
+        if should_merge(cur_utt) and len(final_utts) != 0:
+            final_utts[-1] = final_utts[-1] + cur_utt
+        else:
+            final_utts.append(cur_utt)
+
+    return final_utts
+
+
+# remove blank between chinese character
+def replace_blank(text: str):
+    out_str = []
+    for i, c in enumerate(text):
+        if c == " ":
+            if ((text[i + 1].isascii() and text[i + 1] != " ") and
+                    (text[i - 1].isascii() and text[i - 1] != " ")):
+                out_str.append(c)
+        else:
+            out_str.append(c)
+    return "".join(out_str)
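A brief usage sketch of the splitter above (not part of the diff). The sample text, the plain whitespace tokenizer, and the tiny length limits are hypothetical stand-ins; CosyVoice itself passes its own model tokenizer here.

from xinference.thirdparty.cosyvoice.utils.frontend_utils import (
    contains_chinese,
    split_paragraph,
)

text = "First sentence. Second sentence! Third one?"
# For lang="en" chunk length is measured in tokens returned by `tokenize`,
# so a whitespace split is enough for a demo. With token_max_n=4 and
# token_min_n=2, the three sentences are regrouped into two chunks.
chunks = split_paragraph(text, tokenize=str.split, lang="en",
                         token_max_n=4, token_min_n=2, merge_len=2)
print(contains_chinese(text))  # False
print(chunks)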
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/utils/mask.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2019 Shigeki Karita
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+'''
+def subsequent_mask(
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size).
+
+    This mask is used only in decoder which works in an auto-regressive mode.
+    This means the current step could only do attention with its left steps.
+
+    In encoder, fully attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is need, chunk-based attention is used in encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        str device (str): "cpu" or "cuda" or torch.Tensor.device
+        dtype (torch.device): result dtype
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    ret = torch.ones(size, size, device=device, dtype=torch.bool)
+    return torch.tril(ret)
+'''
+
+
+def subsequent_mask(
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size).
+
+    This mask is used only in decoder which works in an auto-regressive mode.
+    This means the current step could only do attention with its left steps.
+
+    In encoder, fully attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is need, chunk-based attention is used in encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        str device (str): "cpu" or "cuda" or torch.Tensor.device
+        dtype (torch.device): result dtype
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    arange = torch.arange(size, device=device)
+    mask = arange.expand(size, size)
+    arange = arange.unsqueeze(-1)
+    mask = mask <= arange
+    return mask
+
+
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+                            masks: torch.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int,
+                            static_chunk_size: int,
+                            num_decoding_left_chunks: int,
+                            enable_full_context: bool = True):
+    """ Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        mask (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding
+            if it's greater than 0, if use_dynamic_chunk is true,
+            this parameter will be ignored
+        num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+        enable_full_context (bool):
+            True: chunk size is either [1, 25] or full context(max_len)
+            False: chunk size ~ U[1, 25]
+
+    Returns:
+        torch.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.size(1)
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk size is either [1, 25] or full context(max_len).
+            # Since we use 4 times subsampling and allow up to 1s(100 frames)
+            # delay, the maximum frame is 100 / 4 = 25.
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0 ,0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.size(0)
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = torch.arange(0,
+                             max_len,
+                             dtype=torch.int64,
+                             device=lengths.device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
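Finally, a minimal sketch of the two mask primitives (not part of the diff); the sizes below are illustrative, and the expected outputs follow the docstring examples.

import torch

from xinference.thirdparty.cosyvoice.utils.mask import (
    make_pad_mask,
    subsequent_chunk_mask,
)

# Padding mask: True marks positions beyond each sequence's length.
print(make_pad_mask(torch.tensor([5, 3, 2])))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])

# Chunk-causal mask for a streaming encoder: each position attends within
# its own chunk and, with num_left_chunks=-1, to every chunk on its left.
print(subsequent_chunk_mask(4, 2))
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])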