minicpmo-utils 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
+++ cosyvoice/utils/frontend_utils.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import regex
+
+chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
+
+
+# whether the text contains a Chinese character
+def contains_chinese(text):
+    return bool(chinese_char_pattern.search(text))
+
+
+# replace special superscript symbols
+def replace_corner_mark(text):
+    text = text.replace('²', '平方')
+    text = text.replace('³', '立方')
+    return text
+
+
+# remove meaningless symbols
+def remove_bracket(text):
+    text = text.replace('(', '').replace(')', '')
+    text = text.replace('【', '').replace('】', '')
+    text = text.replace('`', '')
+    text = text.replace("——", " ")
+    return text
+
+
+# spell out Arabic numerals as words
+def spell_out_number(text: str, inflect_parser):
+    new_text = []
+    st = None
+    for i, c in enumerate(text):
+        if not c.isdigit():
+            if st is not None:
+                num_str = inflect_parser.number_to_words(text[st: i])
+                new_text.append(num_str)
+                st = None
+            new_text.append(c)
+        else:
+            if st is None:
+                st = i
+    if st is not None and st < len(text):
+        num_str = inflect_parser.number_to_words(text[st:])
+        new_text.append(num_str)
+    return ''.join(new_text)
+
+
+# paragraph-splitting logic:
+# 1. each sentence is at most token_max_n and at least token_min_n tokens long;
+#    the last sentence is merged into the previous one if it is shorter than merge_len
+# 2. sentence length is computed according to the language
+# 3. sentences are split on punctuation
+def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
+    def calc_utt_length(_text: str):
+        if lang == "zh":
+            return len(_text)
+        else:
+            return len(tokenize(_text))
+
+    def should_merge(_text: str):
+        if lang == "zh":
+            return len(_text) < merge_len
+        else:
+            return len(tokenize(_text)) < merge_len
+
+    if lang == "zh":
+        pounc = ['。', '？', '！', '；', '：', '、', '.', '?', '!', ';']
+    else:
+        pounc = ['.', '?', '!', ';', ':']
+    if comma_split:
+        pounc.extend(['，', ','])
+
+    if text[-1] not in pounc:
+        if lang == "zh":
+            text += "。"
+        else:
+            text += "."
+
+    st = 0
+    utts = []
+    for i, c in enumerate(text):
+        if c in pounc:
+            if len(text[st: i]) > 0:
+                utts.append(text[st: i] + c)
+            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
+                tmp = utts.pop(-1)
+                utts.append(tmp + text[i + 1])
+                st = i + 2
+            else:
+                st = i + 1
+
+    final_utts = []
+    cur_utt = ""
+    for utt in utts:
+        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
+            final_utts.append(cur_utt)
+            cur_utt = ""
+        cur_utt = cur_utt + utt
+    if len(cur_utt) > 0:
+        if should_merge(cur_utt) and len(final_utts) != 0:
+            final_utts[-1] = final_utts[-1] + cur_utt
+        else:
+            final_utts.append(cur_utt)
+
+    return final_utts
+
+
+# remove whitespace between Chinese characters
+def replace_blank(text: str):
+    out_str = []
+    for i, c in enumerate(text):
+        if c == " ":
+            # keep a space only when it sits between two non-space ASCII
+            # characters; bounds-checked so leading/trailing spaces cannot
+            # index past the ends of the string
+            if (0 < i < len(text) - 1 and
+                    text[i + 1].isascii() and text[i + 1] != " " and
+                    text[i - 1].isascii() and text[i - 1] != " "):
+                out_str.append(c)
+        else:
+            out_str.append(c)
+    return "".join(out_str)
+
+
+def is_only_punctuation(text):
+    # Regular expression: match strings that consist only of punctuation marks or symbols, or are empty.
+    punctuation_pattern = r'^[\p{P}\p{S}]*$'
+    return bool(regex.fullmatch(punctuation_pattern, text))
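For orientation, a minimal usage sketch of the normalization helpers above. It assumes the third-party inflect package and uses str.split as a stand-in tokenizer; the sample sentence and parameter values are illustrative only, not taken from the package.

# Hypothetical driver for the frontend helpers (not part of the package).
import inflect
from cosyvoice.utils.frontend_utils import (
    contains_chinese, replace_blank, spell_out_number, split_paragraph)

text = "CosyVoice 2 supports streaming TTS with 25 frames of lookahead."
if contains_chinese(text):
    text = replace_blank(text)                        # drop spaces between CJK chars
else:
    text = spell_out_number(text, inflect.engine())   # "25" -> "twenty-five"

# Split into synthesizable sentences; str.split stands in for a real tokenizer.
for utt in split_paragraph(text, tokenize=str.split, lang="en",
                           token_max_n=30, token_min_n=10, merge_len=5):
    print(utt)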
+++ cosyvoice/utils/losses.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+
+
+def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
+    loss = 0
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        m_DG = torch.median((dr - dg))
+        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+        loss += tau - F.relu(tau - L_rel)
+    return loss
+
+
+def mel_loss(real_speech, generated_speech, mel_transforms):
+    loss = 0
+    for transform in mel_transforms:
+        mel_r = transform(real_speech)
+        mel_g = transform(generated_speech)
+        loss += F.l1_loss(mel_g, mel_r)
+    return loss
+
+
+class DPOLoss(torch.nn.Module):
+    """
+    DPO Loss
+    """
+
+    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
+        super().__init__()
+        self.beta = beta
+        self.label_smoothing = label_smoothing
+        self.ipo = ipo
+
+    def forward(
+        self,
+        policy_chosen_logps: torch.Tensor,
+        policy_rejected_logps: torch.Tensor,
+        reference_chosen_logps: torch.Tensor,
+        reference_rejected_logps: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+        ref_logratios = reference_chosen_logps - reference_rejected_logps
+        logits = pi_logratios - ref_logratios
+        if self.ipo:
+            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
+        else:
+            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
+            losses = (
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+            )
+        loss = losses.mean()
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        return loss, chosen_rewards, rejected_rewards
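As a sanity check on DPOLoss, a small sketch of how it might be driven: given summed per-sequence log-probabilities of chosen and rejected responses under the policy and a frozen reference model, the loss pushes the policy's log-ratio margin above the reference's. The tensor values below are made up for illustration.

# Illustrative call of DPOLoss (not from the package); log-probs are invented.
import torch
from cosyvoice.utils.losses import DPOLoss

dpo = DPOLoss(beta=0.1, label_smoothing=0.0, ipo=False)

# Summed per-sequence log-probs for a batch of two preference pairs.
policy_chosen = torch.tensor([-12.3, -15.0])
policy_rejected = torch.tensor([-14.1, -15.2])
ref_chosen = torch.tensor([-12.8, -15.1])
ref_rejected = torch.tensor([-13.9, -15.0])

loss, chosen_rw, rejected_rw = dpo(policy_chosen, policy_rejected,
                                   ref_chosen, ref_rejected)
# loss = mean(-logsigmoid(beta * ((pc - pr) - (rc - rr)))); the rewards are
# beta-scaled, detached log-ratio gaps, handy for logging preference accuracy.
print(loss.item(), (chosen_rw > rejected_rw).float().mean().item())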
+++ cosyvoice/utils/mask.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2019 Shigeki Karita
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+'''
+def subsequent_mask(
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size).
+
+    This mask is used only in the decoder, which works in an auto-regressive
+    mode: the current step may only attend to the steps on its left.
+
+    In the encoder, full attention is used when streaming is not necessary
+    and the sequence is not long. In that case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in the encoder.
+    See subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    ret = torch.ones(size, size, device=device, dtype=torch.bool)
+    return torch.tril(ret)
+'''
+
+
+def subsequent_mask(
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size).
+
+    This mask is used only in the decoder, which works in an auto-regressive
+    mode: the current step may only attend to the steps on its left.
+
+    In the encoder, full attention is used when streaming is not necessary
+    and the sequence is not long. In that case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in the encoder.
+    See subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    arange = torch.arange(size, device=device)
+    mask = arange.expand(size, size)
+    arange = arange.unsqueeze(-1)
+    mask = mask <= arange
+    return mask
+
+
+def subsequent_chunk_mask_deprecated(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size;
+    this is for the streaming encoder.
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size;
+    this is for the streaming encoder.
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+    pos_idx = torch.arange(size, device=device)
+    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+                            masks: torch.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int,
+                            static_chunk_size: int,
+                            num_decoding_left_chunks: int,
+                            enable_full_context: bool = True):
+    """Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        masks (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk,
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding,
+            if it's greater than 0; if use_dynamic_chunk is true,
+            this parameter will be ignored
+        num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+        enable_full_context (bool):
+            True: chunk size is either [1, 25] or full context (max_len)
+            False: chunk size ~ U[1, 25]
+
+    Returns:
+        torch.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.size(1)
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk size is either [1, 25] or full context (max_len).
+            # Since we use 4x subsampling and allow up to 1 s (100 frames)
+            # of delay, the maximum chunk is 100 / 4 = 25.
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        print('chunk_masks is all false at some timestep; forcing it to true. '
+              'Make sure these positions are masked in future computations!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+    return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.size(0)
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = torch.arange(0,
+                             max_len,
+                             dtype=torch.int64,
+                             device=lengths.device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
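To see how the two mask helpers combine, here is a hedged sketch of the encoder-mask construction: make_pad_mask marks padded frames, and subsequent_chunk_mask limits attention to 2-frame chunks, mirroring what add_optional_chunk_mask computes with static_chunk_size=2. The lengths below are illustrative, not from the package.

# Illustrative combination of the mask helpers above (not from the package).
import torch
from cosyvoice.utils.mask import make_pad_mask, subsequent_chunk_mask

lengths = torch.tensor([4, 2])        # two sequences, padded to length 4
pad_mask = ~make_pad_mask(lengths)    # (B, L)  True where frames are valid
masks = pad_mask.unsqueeze(1)         # (B, 1, L) as the encoder expects

chunk = subsequent_chunk_mask(4, chunk_size=2)   # (L, L) streaming pattern
chunk_masks = masks & chunk.unsqueeze(0)         # broadcast to (B, L, L)
print(chunk_masks.int())
# Row i of sequence b is True for frames in chunks up to i's own chunk,
# and False wherever the target frame is padding.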