minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/utils/frontend_utils.py
ADDED
@@ -0,0 +1,136 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import regex

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


# whether the text contains any Chinese character
def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))


# replace special symbols
def replace_corner_mark(text):
    text = text.replace('²', '平方')
    text = text.replace('³', '立方')
    return text


# remove meaningless symbols
def remove_bracket(text):
    text = text.replace('(', '').replace(')', '')
    text = text.replace('【', '').replace('】', '')
    text = text.replace('`', '').replace('`', '')
    text = text.replace("——", " ")
    return text


# spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st: i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return ''.join(new_text)


# split-paragraph logic:
# 1. each sentence has max length token_max_n and min length token_min_n; merge the last sentence into the previous one if it is shorter than merge_len
# 2. calculate sentence length according to lang
# 3. split sentences on punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ['。', '？', '！', '；', '：', '、', '.', '?', '!', ';']
    else:
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend(['，', ','])

    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st: i]) > 0:
                utts.append(text[st: i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts


# remove blanks between Chinese characters
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " ":
            if ((text[i + 1].isascii() and text[i + 1] != " ") and
                    (text[i - 1].isascii() and text[i - 1] != " ")):
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)


def is_only_punctuation(text):
    # Regular expression: match strings that consist only of punctuation marks or are empty.
    punctuation_pattern = r'^[\p{P}\p{S}]*$'
    return bool(regex.fullmatch(punctuation_pattern, text))
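
For orientation, here is a minimal usage sketch of the frontend helpers above. It is not part of the package: it assumes the third-party inflect library for number spelling (which is what the CosyVoice frontend passes as inflect_parser), and a whitespace split stands in for the model's real BPE tokenizer.

import inflect  # third-party; assumed here only to drive spell_out_number

# Hypothetical stand-in for the model tokenizer used by split_paragraph.
tokenize = lambda s: s.split()

p = inflect.engine()
print(spell_out_number("It is 42 meters long.", p))
# -> 'It is forty-two meters long.'

text = "你好, 世 界。今天天气不错!我们出去走走吧?好的。"
# Normalize, then drop blanks that sit between Chinese characters.
text = replace_blank(remove_bracket(replace_corner_mark(text)))
# Small token_max_n/token_min_n values so the toy paragraph actually splits.
for utt in split_paragraph(text, tokenize, lang="zh", token_max_n=10,
                           token_min_n=5, merge_len=2):
    if not is_only_punctuation(utt):
        print(utt)

With these settings each sentence is emitted as its own utterance; larger token_max_n values pack several sentences per chunk, which is what the TTS pipeline feeds to the model.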
cosyvoice/utils/losses.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F
from typing import Tuple


def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        m_DG = torch.median((dr - dg))
        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
        loss += tau - F.relu(tau - L_rel)
    return loss


def mel_loss(real_speech, generated_speech, mel_transforms):
    loss = 0
    for transform in mel_transforms:
        mel_r = transform(real_speech)
        mel_g = transform(generated_speech)
        loss += F.l1_loss(mel_g, mel_r)
    return loss


class DPOLoss(torch.nn.Module):
    """
    DPO Loss
    """

    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        else:
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return loss, chosen_rewards, rejected_rewards
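
A quick smoke test of DPOLoss follows; this is not from the package. Random sequence-level log-probabilities stand in for real policy/reference model outputs, so the shapes and values are illustrative only.

import torch

torch.manual_seed(0)
dpo_loss = DPOLoss(beta=0.1)

# Batch of 4 preference pairs; negative random values mimic log-probs.
policy_chosen = -torch.rand(4)
policy_rejected = -torch.rand(4)
ref_chosen = -torch.rand(4)
ref_rejected = -torch.rand(4)

loss, chosen_rewards, rejected_rewards = dpo_loss(
    policy_chosen, policy_rejected, ref_chosen, ref_rejected)
print(loss.item(), chosen_rewards, rejected_rewards)

The loss pulls the policy's chosen/rejected log-ratio above the reference model's; the returned rewards are the detached per-sample margins, useful for logging reward accuracy during training.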
cosyvoice/utils/mask.py
ADDED
@@ -0,0 +1,265 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
'''
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step may only attend to the steps to its left.

    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret)
'''


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step may only attend to the steps to its left.

    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


def subsequent_chunk_mask_deprecated(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for a streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for a streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
    pos_idx = torch.arange(size, device=device)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; ignored when use_dynamic_chunk is true
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context (max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context (max_len).
            # Since we use 4x subsampling and allow up to 1s (100 frames)
            # of delay, the maximum frame count is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
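
To see the mask utilities working together, here is a small illustrative check (not from the package): it builds padding and chunk masks for a toy batch and verifies that the ONNX-friendly subsequent_chunk_mask agrees with the deprecated loop-based version when num_left_chunks is -1 (the only case the vectorized version supports).

import torch

lengths = torch.tensor([5, 3, 2])
pad_mask = make_pad_mask(lengths)   # (B, T): True at padded positions
masks = (~pad_mask).unsqueeze(1)    # (B, 1, T): True at valid positions

xs = torch.randn(3, 5, 8)           # (B, L, D) dummy encoder input
chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2,
    num_decoding_left_chunks=-1)
print(chunk_masks.shape)            # torch.Size([3, 5, 5])

# The vectorized chunk mask should match the loop-based one.
assert torch.equal(subsequent_chunk_mask(6, 2),
                   subsequent_chunk_mask_deprecated(6, 2))

The resulting (B, L, L) mask combines padding (columns beyond each utterance's length are False) with chunk causality (each frame attends only within its chunk and to earlier chunks), which is what the streaming encoder's attention layers consume.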