minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
@@ -0,0 +1,328 @@
+# Copyright (c) 2019 Shigeki Karita
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True):
+        """Construct a MultiHeadedAttention object."""
+        super().__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor, size
+                (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor, size
+                (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor, size
+                (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value, size
+                (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score, size
+                (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+        #           1st chunk to ease the onnx export.]
+        #   2. pytorch training
+        if mask.size(2) > 0:  # time2 > 0
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            # For last chunk, time2 might be larger than scores.size(-1)
+            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+            scores = scores.masked_fill(mask, -float('inf'))
+            attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0)  # (batch, head, time1, time2)
+        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+        #   1. onnx(16/-1, -1/-1, 16/0)
+        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+        else:
+            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                 self.h * self.d_k)
+             )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                   the batch padding mask for input is in (#batch, 1, T) shape.
+                2. When applying self attention of encoder,
+                   the mask is in (#batch, T, T) shape.
+                3. When applying self attention of decoder,
+                   the mask is in (#batch, L, L) shape.
+                4. If different positions in the decoder see different blocks
+                   of the encoder, such as Mocha, the passed-in mask could be
+                   in (#batch, L, T) shape. But there is no such case in current
+                   CosyVoice.
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+
+        # NOTE(xcsong):
+        #   when exporting an onnx model, for the 1st chunk, we feed
+        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #       In all modes, `if cache.size(0) > 0` will always be `True`
+        #       and we will always do splitting and
+        #       concatenation (this will simplify onnx export). Note that
+        #       it's OK to concat & split zero-shaped tensors (see code below).
+        #   when exporting a jit model, for the 1st chunk, we always feed
+        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = torch.cat((k, v), dim=-1)
+
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding.
+    Paper: https://arxiv.org/abs/1901.02860
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+    """
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True):
+        """Construct a RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate, key_bias)
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+                time1 means the length of the query vector.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+        ]  # only keep the positions from 0 to time2
+        return x
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        # NOTE(xcsong):
+        #   when exporting an onnx model, for the 1st chunk, we feed
+        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #       In all modes, `if cache.size(0) > 0` will always be `True`
+        #       and we will always do splitting and
+        #       concatenation (this will simplify onnx export). Note that
+        #       it's OK to concat & split zero-shaped tensors (see code below).
+        #   when exporting a jit model, for the 1st chunk, we always feed
+        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache is not None and cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = torch.cat((k, v), dim=-1)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+        if matrix_ac.shape != matrix_bd.shape:
+            matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k)  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask), new_cache
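
The hunk above (+328 lines, which matches stepaudio2/cosyvoice2/transformer/attention.py in the file list) implements standard multi-head attention with an optional streaming key/value cache that is simply concatenated along the time axis. The snippet below is a minimal shape-check sketch, not part of the package; the import path and hyperparameters (4 heads, 256 features) are assumptions made for illustration.

    import torch
    # assumed import path, inferred from the file list above
    # from stepaudio2.cosyvoice2.transformer.attention import MultiHeadedAttention

    attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

    x1 = torch.randn(1, 10, 256)                      # first chunk (batch, time, feat)
    empty_cache = torch.zeros((0, 0, 0, 0))           # jit-style "no cache yet"
    out1, cache1 = attn(x1, x1, x1, cache=empty_cache)
    print(out1.shape, cache1.shape)                   # (1, 10, 256), (1, 4, 10, 128) since d_k*2 = 128

    x2 = torch.randn(1, 5, 256)                       # next chunk attends to cached K/V as well
    out2, cache2 = attn(x2, x2, x2, cache=cache1)
    print(out2.shape, cache2.shape)                   # (1, 5, 256), (1, 4, 15, 128)
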
@@ -0,0 +1,119 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Positional Encoding Module."""
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+        """Construct a PositionalEncoding object."""
+        super(EspnetRelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means the position of the query vector and `j` means the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
+            -> Tuple[torch.Tensor, torch.Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self,
+                          offset: Union[int, torch.Tensor],
+                          size: int) -> torch.Tensor:
+        """For getting encoding in a streaming fashion.
+
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a
+        non-streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+
+        Args:
+            offset (int or torch.Tensor): start offset
+            size (int): required size of position encoding
+
+        Returns:
+            torch.Tensor: Corresponding encoding
+        """
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
+        ]
+        return pos_emb
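
The hunk above (+119 lines, matching stepaudio2/cosyvoice2/transformer/embedding.py in the file list) builds an ESPnet-style relative positional table that stores both positive and negative offsets, so an input of length T receives a (1, 2*T-1, d_model) positional embedding suitable for the rel_shift trick in the attention module. A hedged sketch of that shape contract (not part of the package; the import path and sizes are assumptions):

    import torch
    # assumed import path, inferred from the file list above
    # from stepaudio2.cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding

    pos_enc = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0, max_len=5000)
    x = torch.randn(2, 100, 256)        # (batch, time, d_model)
    x_scaled, pos_emb = pos_enc(x)
    print(x_scaled.shape)               # torch.Size([2, 100, 256]); input scaled by sqrt(d_model)
    print(pos_emb.shape)                # torch.Size([1, 199, 256]) == 2*100 - 1 relative positions
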
@@ -0,0 +1,163 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class ConformerEncoderLayer(nn.Module):
+    """Encoder layer module.
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward` instance can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+            instance.
+            `PositionwiseFeedForward` instance can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool):
+            True: use layer_norm before each sub-block.
+            False: use layer_norm after each sub-block.
+        enable_cuda_graph (bool): Control whether to enable CUDA Graph.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        self_attn: torch.nn.Module,
+        feed_forward: Optional[nn.Module] = None,
+        feed_forward_macaron: Optional[nn.Module] = None,
+        conv_module: Optional[nn.Module] = None,
+        dropout_rate: float = 0.1,
+        normalize_before: bool = True,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
+            self.norm_final = nn.LayerNorm(
+                size, eps=1e-12)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): (#batch, time, size)
+            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+                (0, 0, 0) means fake mask.
+            pos_emb (torch.Tensor): positional encoding, must not be None
+                for ConformerEncoderLayer.
+            mask_pad (torch.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (torch.Tensor): Convolution cache in conformer layer
+                (#batch=1, size, cache_t2)
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time, time).
+            torch.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+        """
+        return self._forward_impl(x, mask, pos_emb, mask_pad, att_cache, cnn_cache)
+
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Original forward-pass implementation."""
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(
+                self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+        # att_cache: (b, head, cache_t, d_k*2)
+        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+                                              att_cache)
+        x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        # Fake new cnn cache here, and then change it in conv_module
+        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+            x = residual + self.dropout(x)
+
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+        return x, mask, new_att_cache, new_cnn_cache
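
The hunk above (+163 lines, matching stepaudio2/cosyvoice2/transformer/encoder_layer.py in the file list) composes the pieces shown in this diff: an optional macaron feed-forward, relative-position self-attention, an optional convolution module, and a final feed-forward block, each wrapped in pre- or post-LayerNorm residual connections. Below is a hedged wiring sketch, not part of the package; sizes and import paths are assumptions, and the convolution module is left out for brevity.

    import torch
    # assumed import paths, inferred from the file list above
    # from stepaudio2.cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
    # from stepaudio2.cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
    # from stepaudio2.cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
    # from stepaudio2.cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding

    size, heads = 256, 4
    layer = ConformerEncoderLayer(
        size=size,
        self_attn=RelPositionMultiHeadedAttention(heads, size, dropout_rate=0.0),
        feed_forward=PositionwiseFeedForward(size, 1024, dropout_rate=0.0),
        feed_forward_macaron=PositionwiseFeedForward(size, 1024, dropout_rate=0.0),
        conv_module=None,                       # skip the CNN branch in this sketch
        dropout_rate=0.0,
    )
    pos_enc = EspnetRelPositionalEncoding(d_model=size, dropout_rate=0.0)

    x = torch.randn(1, 50, size)
    mask = torch.ones(1, 50, 50, dtype=torch.bool)  # full (batch, time, time) self-attention mask
    x_scaled, pos_emb = pos_enc(x)
    y, mask, att_cache, cnn_cache = layer(x_scaled, mask, pos_emb)
    print(y.shape)                                  # torch.Size([1, 50, 256])
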
@@ -0,0 +1,56 @@
+# Copyright (c) 2019 Shigeki Karita
+#               2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+    """Positionwise feed forward layer.
+
+    FeedForward is applied on each position of the sequence.
+    The output dim is the same as the input dim.
+
+    Args:
+        idim (int): Input dimension.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
+        activation (torch.nn.Module): Activation function
+    """
+
+    def __init__(
+        self,
+        idim: int,
+        hidden_units: int,
+        dropout_rate: float,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+    ):
+        """Construct a PositionwiseFeedForward object."""
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.activation = activation
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+    def forward(self, xs: torch.Tensor) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            xs: input tensor (B, L, D)
+        Returns:
+            output tensor, (B, L, D)
+        """
+        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
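
The hunk above (+56 lines, matching stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py in the file list) is a plain two-layer MLP applied independently at every time step: Linear -> activation -> dropout -> Linear, mapping (B, L, D) to (B, L, D) without mixing positions. A hedged sketch of that per-position behaviour, with illustrative sizes:

    import torch
    # assumed import path, inferred from the file list above
    # from stepaudio2.cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward

    ff = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.0)
    xs = torch.randn(2, 7, 256)
    ys = ff(xs)
    print(ys.shape)                                   # torch.Size([2, 7, 256])
    # per-position: row i of the output depends only on row i of the input
    assert torch.allclose(ys[:, 3], ff(xs[:, 3:4]).squeeze(1), atol=1e-6)
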