minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py
@@ -0,0 +1,998 @@
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def subsequent_chunk_mask(
+         size: int,
+         chunk_size: int,
+         num_left_chunks: int = -1,
+         device: torch.device = torch.device("cpu"),
+ ) -> torch.Tensor:
+     """Create mask for subsequent steps (size, size) with chunk size,
+     this is for streaming encoder
+
+     Args:
+         size (int): size of mask
+         chunk_size (int): size of chunk
+         num_left_chunks (int): number of left chunks
+             <0: use full chunk
+             >=0: use num_left_chunks
+         device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+     Returns:
+         torch.Tensor: mask
+
+     Examples:
+         >>> subsequent_chunk_mask(4, 2)
+         [[1, 1, 0, 0],
+          [1, 1, 0, 0],
+          [1, 1, 1, 1],
+          [1, 1, 1, 1]]
+     """
+     # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+     pos_idx = torch.arange(size, device=device)
+     block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+     ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+     return ret
+
+
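For reference, a minimal sketch (not part of the package diff) comparing the ONNX-friendly arange/trunc-div construction above with a naive row-wise loop; `naive_chunk_mask` is a hypothetical helper written only for this comparison.

import torch

def naive_chunk_mask(size: int, chunk_size: int) -> torch.Tensor:
    # Row i may attend to every position in chunks 0..(i // chunk_size).
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        ret[i, :min(((i // chunk_size) + 1) * chunk_size, size)] = True
    return ret

def onnx_friendly_chunk_mask(size: int, chunk_size: int) -> torch.Tensor:
    # Same construction as subsequent_chunk_mask in the diff above.
    pos_idx = torch.arange(size)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    return pos_idx.unsqueeze(0) < block_value.unsqueeze(1)

assert torch.equal(naive_chunk_mask(10, 4), onnx_friendly_chunk_mask(10, 4))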
+ def add_optional_chunk_mask(xs: torch.Tensor,
+                             masks: torch.Tensor,
+                             use_dynamic_chunk: bool,
+                             use_dynamic_left_chunk: bool,
+                             decoding_chunk_size: int,
+                             static_chunk_size: int,
+                             num_decoding_left_chunks: int,
+                             enable_full_context: bool = True):
+     """ Apply optional mask for encoder.
+
+     Args:
+         xs (torch.Tensor): padded input, (B, L, D), L for max length
+         masks (torch.Tensor): mask for xs, (B, 1, L)
+         use_dynamic_chunk (bool): whether to use dynamic chunk or not
+         use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+             training.
+         decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+             0: default for training, use random dynamic chunk.
+             <0: for decoding, use full chunk.
+             >0: for decoding, use fixed chunk size as set.
+         static_chunk_size (int): chunk size for static chunk training/decoding
+             if it's greater than 0, if use_dynamic_chunk is true,
+             this parameter will be ignored
+         num_decoding_left_chunks: number of left chunks, this is for decoding,
+             the chunk size is decoding_chunk_size.
+             >=0: use num_decoding_left_chunks
+             <0: use all left chunks
+         enable_full_context (bool):
+             True: chunk size is either [1, 25] or full context(max_len)
+             False: chunk size ~ U[1, 25]
+
+     Returns:
+         torch.Tensor: chunk mask of the input xs.
+     """
+     # Whether to use chunk mask or not
+     if use_dynamic_chunk:
+         max_len = xs.size(1)
+         if decoding_chunk_size < 0:
+             chunk_size = max_len
+             num_left_chunks = -1
+         elif decoding_chunk_size > 0:
+             chunk_size = decoding_chunk_size
+             num_left_chunks = num_decoding_left_chunks
+         else:
+             # chunk size is either [1, 25] or full context(max_len).
+             # Since we use 4 times subsampling and allow up to 1s(100 frames)
+             # delay, the maximum frame is 100 / 4 = 25.
+             chunk_size = torch.randint(1, max_len, (1, )).item()
+             num_left_chunks = -1
+             if chunk_size > max_len // 2 and enable_full_context:
+                 chunk_size = max_len
+             else:
+                 chunk_size = chunk_size % 25 + 1
+                 if use_dynamic_left_chunk:
+                     max_left_chunks = (max_len - 1) // chunk_size
+                     num_left_chunks = torch.randint(0, max_left_chunks,
+                                                     (1, )).item()
+         chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                             num_left_chunks,
+                                             xs.device)  # (L, L)
+         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+         chunk_masks = masks & chunk_masks  # (B, L, L)
+     elif static_chunk_size > 0:
+         num_left_chunks = num_decoding_left_chunks
+         chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                             num_left_chunks,
+                                             xs.device)  # (L, L)
+         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+         chunk_masks = masks & chunk_masks  # (B, L, L)
+     else:
+         chunk_masks = masks
+     assert chunk_masks.dtype == torch.bool
+     if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+         print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
+         chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+     return chunk_masks
+
+
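A toy sketch (not part of the diff) of the `static_chunk_size > 0` branch above, assuming the module is importable under the path listed in this wheel. The (B, 1, L) padding mask is ANDed with the (1, L, L) block-causal mask to give a (B, L, L) attention mask; the toy sizes are arbitrary.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import (
    add_optional_chunk_mask, make_pad_mask)

B, L, D, chunk = 2, 6, 4, 2
xs = torch.randn(B, L, D)
lengths = torch.tensor([6, 4])
masks = ~make_pad_mask(lengths, L).unsqueeze(1)   # (B, 1, L), True = valid frame
chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=chunk,
    num_decoding_left_chunks=-1)
# Rows belonging to padded frames of the shorter utterance come out all-False
# and are force-set to True by the warning branch above.
print(chunk_masks.shape, chunk_masks.dtype)       # torch.Size([2, 6, 6]) torch.bool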
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+     """Make mask tensor containing indices of padded part.
+
+     See description of make_non_pad_mask.
+
+     Args:
+         lengths (torch.Tensor): Batch of lengths (B,).
+     Returns:
+         torch.Tensor: Mask tensor containing indices of padded part.
+
+     Examples:
+         >>> lengths = [5, 3, 2]
+         >>> make_pad_mask(lengths)
+         masks = [[0, 0, 0, 0, 0],
+                  [0, 0, 0, 1, 1],
+                  [0, 0, 1, 1, 1]]
+     """
+     batch_size = lengths.size(0)
+     max_len = max_len if max_len > 0 else lengths.max().item()
+     seq_range = torch.arange(0,
+                              max_len,
+                              dtype=torch.int64,
+                              device=lengths.device)
+     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+     seq_length_expand = lengths.unsqueeze(-1)
+     mask = seq_range_expand >= seq_length_expand
+     return mask
+
+
+ class EspnetRelPositionalEncoding(torch.nn.Module):
+     """Relative positional encoding module (new implementation).
+
+     Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+     See: Appendix B in https://arxiv.org/abs/1901.02860
+
+     Args:
+         d_model (int): Embedding dimension.
+         max_len (int): Maximum input length.
+
+     """
+
+     def __init__(self, d_model: int, max_len: int = 5000):
+         super(EspnetRelPositionalEncoding, self).__init__()
+         self.d_model = d_model
+         self.xscale = math.sqrt(self.d_model)
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+     def extend_pe(self, x: torch.Tensor):
+         """Reset the positional encodings."""
+         if self.pe is not None:
+             # self.pe contains both positive and negative parts
+             # the length of self.pe is 2 * input_len - 1
+             if self.pe.size(1) >= x.size(1) * 2 - 1:
+                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                 return
+         # Suppose `i` means the position of the query vector and `j` means the
+         # position of the key vector. We use positive relative positions when keys
+         # are to the left (i>j) and negative relative positions otherwise (i<j).
+         pe_positive = torch.zeros(x.size(1), self.d_model)
+         pe_negative = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32)
+             * -(math.log(10000.0) / self.d_model)
+         )
+         pe_positive[:, 0::2] = torch.sin(position * div_term)
+         pe_positive[:, 1::2] = torch.cos(position * div_term)
+         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+         # Reverse the order of positive indices and concat both positive and
+         # negative indices. This is used to support the shifting trick
+         # as in https://arxiv.org/abs/1901.02860
+         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+         pe_negative = pe_negative[1:].unsqueeze(0)
+         pe = torch.cat([pe_positive, pe_negative], dim=1)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
+             -> Tuple[torch.Tensor, torch.Tensor]:
+         """Add positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, time, `*`).
+
+         Returns:
+             torch.Tensor: Encoded tensor (batch, time, `*`).
+
+         """
+         self.extend_pe(x)
+         x = x * self.xscale
+         pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+         return x, pos_emb
+
+     def position_encoding(self,
+                           offset: Union[int, torch.Tensor],
+                           size: int) -> torch.Tensor:
+         """ For getting encoding in a streaming fashion
+
+         Attention!!!!!
+         we apply dropout only once at the whole utterance level in a
+         non-streaming way, but will call this function several times with
+         increasing input size in a streaming scenario, so the dropout will
+         be applied several times.
+
+         Args:
+             offset (int or torch.tensor): start offset
+             size (int): required size of position encoding
+
+         Returns:
+             torch.Tensor: Corresponding encoding
+         """
+         # How to subscript a Union type:
+         # https://github.com/pytorch/pytorch/issues/69434
+         if isinstance(offset, int):
+             pos_emb = self.pe[
+                 :,
+                 self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+             ]
+         elif isinstance(offset, torch.Tensor):
+             pos_emb = self.pe[
+                 :,
+                 self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+             ]
+         return pos_emb
+
+
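A shape sketch (not part of the diff, assuming the class is importable under the path listed in this wheel): `self.pe` covers both positive and negative relative positions, so its length is 2 * max_len - 1, and forward() returns the scaled input together with a 2*T - 1 slice of it.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import EspnetRelPositionalEncoding

pos_enc = EspnetRelPositionalEncoding(d_model=512, max_len=5000)
print(pos_enc.pe.shape)                 # torch.Size([1, 9999, 512]) == 2 * 5000 - 1
x = torch.randn(2, 100, 512)
x_scaled, pos_emb = pos_enc(x)          # offset defaults to 0
print(x_scaled.shape, pos_emb.shape)    # (2, 100, 512) and (1, 199, 512) == 2 * 100 - 1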
+ class LinearNoSubsampling(torch.nn.Module):
+     """Linear transform the input without subsampling
+
+     Args:
+         idim (int): Input dimension.
+         odim (int): Output dimension.
+         pos_enc_class (torch.nn.Module): Positional encoding class.
+
+     """
+
+     def __init__(self, idim: int, odim: int,
+                  pos_enc_class: torch.nn.Module):
+         super().__init__()
+         self.out = torch.nn.Sequential(
+             torch.nn.Linear(idim, odim),
+             torch.nn.LayerNorm(odim, eps=1e-5),
+         )
+         self.pos_enc = pos_enc_class
+         self.right_context = 0
+         self.subsampling_rate = 1
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         x_mask: torch.Tensor,
+         offset: Union[int, torch.Tensor] = 0
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Input x.
+
+         Args:
+             x (torch.Tensor): Input tensor (#batch, time, idim).
+             x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+         Returns:
+             torch.Tensor: linear input tensor (#batch, time', odim),
+                 where time' = time.
+             torch.Tensor: linear input mask (#batch, 1, time'),
+                 where time' = time.
+
+         """
+         x = self.out(x)
+         x, pos_emb = self.pos_enc(x, offset)
+         return x, pos_emb, x_mask
+
+     def position_encoding(self, offset: Union[int, torch.Tensor],
+                           size: int) -> torch.Tensor:
+         return self.pos_enc.position_encoding(offset, size)
+
+
+ class Upsample1D(nn.Module):
+     """A 1D upsampling layer followed by a convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs.
+         out_channels (`int`):
+             number of output channels.
+         stride (`int`, default `2`):
+             upsampling factor; the input is repeat-interpolated by `stride`,
+             then passed through a stride-1 convolution.
+     """
+
+     def __init__(self, channels: int, out_channels: int, stride: int = 2):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels
+         self.stride = stride
+         # In this mode, first repeat interpolate, then conv with stride=1
+         self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
+
+     def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
+         outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+         outputs = self.conv(outputs)
+         return outputs, input_lengths * self.stride
+
+
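A small sketch of the Upsample1D block above (not part of the diff, assuming it is importable under the path listed in this wheel): nearest-neighbour repetition by `stride`, a left pad of stride * 2, then a kernel of size stride * 2 + 1 with stride 1, so the output length is exactly stride times the input length and no future frames are consumed.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
feats = torch.randn(3, 512, 40)          # (batch, channels, time)
lengths = torch.tensor([40, 32, 25])
out, out_lengths = up(feats, lengths)
print(out.shape)                         # torch.Size([3, 512, 80])
print(out_lengths)                       # tensor([80, 64, 50])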
+ class PreLookaheadLayer(nn.Module):
+     def __init__(self, channels: int, pre_lookahead_len: int = 1):
+         super().__init__()
+         self.channels = channels
+         self.pre_lookahead_len = pre_lookahead_len
+         self.conv1 = nn.Conv1d(
+             channels, channels,
+             kernel_size=pre_lookahead_len + 1,
+             stride=1, padding=0,
+         )
+         self.conv2 = nn.Conv1d(
+             channels, channels,
+             kernel_size=3, stride=1, padding=0,
+         )
+
+     def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
+         """
+         inputs: (batch_size, seq_len, channels)
+         """
+         outputs = inputs.transpose(1, 2).contiguous()
+         context = context.transpose(1, 2).contiguous()
+         # look ahead
+         if context.size(2) == 0:
+             outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+         else:
+             assert self.training is False, 'you have passed context, make sure that you are running inference mode'
+             assert context.size(2) == self.pre_lookahead_len
+             outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
+         outputs = F.leaky_relu(self.conv1(outputs))
+         # causal conv2 with left padding
+         outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
+         outputs = self.conv2(outputs)
+         outputs = outputs.transpose(1, 2).contiguous()
+
+         # residual connection
+         outputs = outputs + inputs
+         return outputs
+
+
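A sketch of PreLookaheadLayer above (not part of the diff, assuming it is importable under the path listed in this wheel): with no context the input is right-padded with zeros so conv1 can look ahead pre_lookahead_len frames; at inference real future frames can be passed via `context` instead. The output keeps the input shape.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import PreLookaheadLayer

layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3).eval()
x = torch.randn(1, 40, 512)                  # (batch, seq_len, channels)
with torch.no_grad():
    y = layer(x)                             # zero look-ahead padding
    ctx = torch.randn(1, 3, 512)             # 3 real future frames
    y_ctx = layer(x, context=ctx)
print(y.shape, y_ctx.shape)                  # both torch.Size([1, 40, 512])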
+ class MultiHeadedAttention(nn.Module):
+     """Multi-Head Attention layer.
+
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+         key_bias (bool): Whether to use bias in key linear layer.
+
+     """
+
+     def __init__(self,
+                  n_head: int,
+                  n_feat: int,
+                  dropout_rate: float,
+                  key_bias: bool = True):
+         super().__init__()
+         assert n_feat % n_head == 0
+         # We assume d_v always equals d_k
+         self.d_k = n_feat // n_head
+         self.h = n_head
+         self.linear_q = nn.Linear(n_feat, n_feat)
+         self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+         self.linear_v = nn.Linear(n_feat, n_feat)
+         self.linear_out = nn.Linear(n_feat, n_feat)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward_qkv(
+         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Transform query, key and value.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+
+         Returns:
+             torch.Tensor: Transformed query tensor, size
+                 (#batch, n_head, time1, d_k).
+             torch.Tensor: Transformed key tensor, size
+                 (#batch, n_head, time2, d_k).
+             torch.Tensor: Transformed value tensor, size
+                 (#batch, n_head, time2, d_k).
+
+         """
+         n_batch = query.size(0)
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+         return q, k, v
+
+     def forward_attention(
+         self,
+         value: torch.Tensor,
+         scores: torch.Tensor,
+         mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+     ) -> torch.Tensor:
+         """Compute attention context vector.
+
+         Args:
+             value (torch.Tensor): Transformed value, size
+                 (#batch, n_head, time2, d_k).
+             scores (torch.Tensor): Attention score, size
+                 (#batch, n_head, time1, time2).
+             mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+                 (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+         Returns:
+             torch.Tensor: Transformed value (#batch, time1, d_model)
+                 weighted by the attention score (#batch, time1, time2).
+
+         """
+         n_batch = value.size(0)
+         # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+         #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+         #      1st chunk to ease the onnx export.]
+         #   2. pytorch training
+         if mask.size(2) > 0:  # time2 > 0
+             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+             # For last chunk, time2 might be larger than scores.size(-1)
+             mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+             scores = scores.masked_fill(mask, -float('inf'))
+             attn = torch.softmax(scores, dim=-1).masked_fill(
+                 mask, 0.0)  # (batch, head, time1, time2)
+         # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+         #   1. onnx(16/-1, -1/-1, 16/0)
+         #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+         else:
+             attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+         p_attn = self.dropout(attn)
+         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+         x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                  self.h * self.d_k)
+              )  # (batch, time1, d_model)
+
+         return self.linear_out(x)  # (batch, time1, d_model)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+         pos_emb: torch.Tensor = torch.empty(0),
+         cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute scaled dot product attention.
+
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2).
+                 1. When applying cross attention between decoder and encoder,
+                    the batch padding mask for input is in (#batch, 1, T) shape.
+                 2. When applying self attention of encoder,
+                    the mask is in (#batch, T, T) shape.
+                 3. When applying self attention of decoder,
+                    the mask is in (#batch, L, L) shape.
+                 4. If different positions in the decoder see different blocks
+                    of the encoder, such as Mocha, the passed-in mask could be
+                    in (#batch, L, T) shape. But there is no such case in current
+                    CosyVoice.
+             cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+
+
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+             torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+
+         # NOTE(xcsong):
+         #   when export onnx model, for 1st chunk, we feed
+         #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+         #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+         #       In all modes, `if cache.size(0) > 0` will always be `True`
+         #       and we will always do splitting and
+         #       concatenation (this will simplify onnx export). Note that
+         #       it's OK to concat & split zero-shaped tensors (see code below).
+         #   when export jit model, for 1st chunk, we always feed
+         #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+         #   >>> a = torch.ones((1, 2, 0, 4))
+         #   >>> b = torch.ones((1, 2, 3, 4))
+         #   >>> c = torch.cat((a, b), dim=2)
+         #   >>> torch.equal(b, c)  # True
+         #   >>> d = torch.split(a, 2, dim=-1)
+         #   >>> torch.equal(d[0], d[1])  # True
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(cache,
+                                                  cache.size(-1) // 2,
+                                                  dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+         #   non-trivial to calculate `next_cache_start` here.
+         new_cache = torch.cat((k, v), dim=-1)
+
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+         return self.forward_attention(v, scores, mask), new_cache
+
+
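A streaming-cache sketch for the attention module above (not part of the diff, assuming it is importable under the path listed in this wheel): the first call passes an empty cache, each call returns new_cache = cat(k, v) along the last dimension, and feeding it back lets the next chunk attend over the accumulated keys/values.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0).eval()
chunk1 = torch.randn(1, 16, 512)
chunk2 = torch.randn(1, 16, 512)
with torch.no_grad():
    out1, cache = attn(chunk1, chunk1, chunk1)               # cache starts empty
    print(cache.shape)                                       # (1, 8, 16, 128): k and v, d_k * 2 = 128
    out2, cache = attn(chunk2, chunk2, chunk2, cache=cache)  # now attends over 32 frames
    print(out2.shape, cache.shape)                           # (1, 16, 512) and (1, 8, 32, 128)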
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+     """Multi-Head Attention layer with relative position encoding.
+     Paper: https://arxiv.org/abs/1901.02860
+     Args:
+         n_head (int): The number of heads.
+         n_feat (int): The number of features.
+         dropout_rate (float): Dropout rate.
+         key_bias (bool): Whether to use bias in key linear layer.
+     """
+
+     def __init__(self,
+                  n_head: int,
+                  n_feat: int,
+                  dropout_rate: float,
+                  key_bias: bool = True):
+         super().__init__(n_head, n_feat, dropout_rate, key_bias)
+         # linear transformation for positional encoding
+         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+         # these two learnable bias are used in matrix c and matrix d
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+         torch.nn.init.xavier_uniform_(self.pos_bias_u)
+         torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+     def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+         """Compute relative positional encoding.
+
+         Args:
+             x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+                 time1 means the length of query vector.
+
+         Returns:
+             torch.Tensor: Output tensor.
+
+         """
+         zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                                device=x.device,
+                                dtype=x.dtype)
+         x_padded = torch.cat([zero_pad, x], dim=-1)
+
+         x_padded = x_padded.view(x.size()[0],
+                                  x.size()[1],
+                                  x.size(3) + 1, x.size(2))
+         x = x_padded[:, :, 1:].view_as(x)[
+             :, :, :, : x.size(-1) // 2 + 1
+         ]  # only keep the positions from 0 to time2
+         return x
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+         pos_emb: torch.Tensor = torch.empty(0),
+         cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+         Args:
+             query (torch.Tensor): Query tensor (#batch, time1, size).
+             key (torch.Tensor): Key tensor (#batch, time2, size).
+             value (torch.Tensor): Value tensor (#batch, time2, size).
+             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                 (#batch, time1, time2), (0, 0, 0) means fake mask.
+             pos_emb (torch.Tensor): Positional embedding tensor
+                 (#batch, time2, size).
+             cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+         Returns:
+             torch.Tensor: Output tensor (#batch, time1, d_model).
+             torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                 where `cache_t == chunk_size * num_decoding_left_chunks`
+                 and `head * d_k == size`
+         """
+         q, k, v = self.forward_qkv(query, key, value)
+         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+         # NOTE(xcsong):
+         #   when export onnx model, for 1st chunk, we feed
+         #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+         #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+         #       In all modes, `if cache.size(0) > 0` will always be `True`
+         #       and we will always do splitting and
+         #       concatenation (this will simplify onnx export). Note that
+         #       it's OK to concat & split zero-shaped tensors (see code below).
+         #   when export jit model, for 1st chunk, we always feed
+         #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+         #   >>> a = torch.ones((1, 2, 0, 4))
+         #   >>> b = torch.ones((1, 2, 3, 4))
+         #   >>> c = torch.cat((a, b), dim=2)
+         #   >>> torch.equal(b, c)  # True
+         #   >>> d = torch.split(a, 2, dim=-1)
+         #   >>> torch.equal(d[0], d[1])  # True
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(cache,
+                                                  cache.size(-1) // 2,
+                                                  dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+         #   non-trivial to calculate `next_cache_start` here.
+         new_cache = torch.cat((k, v), dim=-1)
+
+         n_batch_pos = pos_emb.size(0)
+         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+         p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+         # (batch, head, time1, d_k)
+         q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+         # (batch, head, time1, d_k)
+         q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+         # compute attention score
+         # first compute matrix a and matrix c
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         # (batch, head, time1, time2)
+         matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+         # compute matrix b and matrix d
+         # (batch, head, time1, time2)
+         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+         # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+         if matrix_ac.shape != matrix_bd.shape:
+             matrix_bd = self.rel_shift(matrix_bd)
+
+         scores = (matrix_ac + matrix_bd) / math.sqrt(
+             self.d_k)  # (batch, head, time1, time2)
+
+         return self.forward_attention(v, scores, mask), new_cache
+
+
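A sketch tying the two modules together (not part of the diff, assuming both are importable under the path listed in this wheel): RelPositionMultiHeadedAttention consumes the (1, 2*T-1, d) relative pos_emb produced by EspnetRelPositionalEncoding, and rel_shift is applied when the ac and bd score matrices differ in shape.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import (
    EspnetRelPositionalEncoding, RelPositionMultiHeadedAttention)

pos_enc = EspnetRelPositionalEncoding(d_model=512)
attn = RelPositionMultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0).eval()
x = torch.randn(2, 50, 512)
with torch.no_grad():
    x_scaled, pos_emb = pos_enc(x)                       # pos_emb: (1, 99, 512)
    out, new_cache = attn(x_scaled, x_scaled, x_scaled, pos_emb=pos_emb)
print(out.shape, new_cache.shape)                        # (2, 50, 512) and (2, 8, 50, 128)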
+ class PositionwiseFeedForward(torch.nn.Module):
+     """Positionwise feed forward layer.
+
+     FeedForward is applied on each position of the sequence.
+     The output dim is same with the input dim.
+
+     Args:
+         idim (int): Input dimension.
+         hidden_units (int): The number of hidden units.
+         dropout_rate (float): Dropout rate.
+         activation (torch.nn.Module): Activation function
+     """
+
+     def __init__(
+         self,
+         idim: int,
+         hidden_units: int,
+         dropout_rate: float,
+         activation: torch.nn.Module = torch.nn.ReLU(),
+     ):
+         super(PositionwiseFeedForward, self).__init__()
+         self.w_1 = torch.nn.Linear(idim, hidden_units)
+         self.activation = activation
+         self.dropout = torch.nn.Dropout(dropout_rate)
+         self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+     def forward(self, xs: torch.Tensor) -> torch.Tensor:
+         """Forward function.
+
+         Args:
+             xs: input tensor (B, L, D)
+         Returns:
+             output tensor, (B, L, D)
+         """
+         return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+
+ class ConformerEncoderLayer(nn.Module):
+     """Encoder layer module.
+     Args:
+         size (int): Input dimension.
+         self_attn (torch.nn.Module): Self-attention module instance.
+             `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+             instance can be used as the argument.
+         feed_forward (torch.nn.Module): Feed-forward module instance.
+             `PositionwiseFeedForward` instance can be used as the argument.
+         feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+             instance.
+             `PositionwiseFeedForward` instance can be used as the argument.
+         conv_module (torch.nn.Module): Convolution module instance.
+             `ConvolutionModule` instance can be used as the argument.
+         dropout_rate (float): Dropout rate.
+         normalize_before (bool):
+             True: use layer_norm before each sub-block.
+             False: use layer_norm after each sub-block.
+     """
+
+     def __init__(
+         self,
+         size: int,
+         self_attn: torch.nn.Module,
+         feed_forward: Optional[nn.Module] = None,
+         feed_forward_macaron: Optional[nn.Module] = None,
+         conv_module: Optional[nn.Module] = None,
+         dropout_rate: float = 0.0,
+         normalize_before: bool = True,
+     ):
+         super().__init__()
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+         self.feed_forward_macaron = feed_forward_macaron
+         self.conv_module = conv_module
+         self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+         self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
+         if feed_forward_macaron is not None:
+             self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
+             self.ff_scale = 0.5
+         else:
+             self.ff_scale = 1.0
+         if self.conv_module is not None:
+             self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
+             self.norm_final = nn.LayerNorm(
+                 size, eps=1e-12)  # for the final output of the block
+         self.dropout = nn.Dropout(dropout_rate)
+         self.size = size
+         self.normalize_before = normalize_before
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: torch.Tensor,
+         pos_emb: torch.Tensor,
+         mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+         att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+         cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Compute encoded features.
+
+         Args:
+             x (torch.Tensor): (#batch, time, size)
+             mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+                 (0, 0, 0) means fake mask.
+             pos_emb (torch.Tensor): positional encoding, must not be None
+                 for ConformerEncoderLayer.
+             mask_pad (torch.Tensor): batch padding mask used for conv module.
+                 (#batch, 1, time), (0, 0, 0) means fake mask.
+             att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+                 (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+             cnn_cache (torch.Tensor): Convolution cache in conformer layer
+                 (#batch=1, size, cache_t2)
+         Returns:
+             torch.Tensor: Output tensor (#batch, time, size).
+             torch.Tensor: Mask tensor (#batch, time, time).
+             torch.Tensor: att_cache tensor,
+                 (#batch=1, head, cache_t1 + time, d_k * 2).
+             torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+         """
+
+         # whether to use macaron style
+         if self.feed_forward_macaron is not None:
+             residual = x
+             if self.normalize_before:
+                 x = self.norm_ff_macaron(x)
+             x = residual + self.ff_scale * self.dropout(
+                 self.feed_forward_macaron(x))
+             if not self.normalize_before:
+                 x = self.norm_ff_macaron(x)
+
+         # multi-headed self-attention module
+         residual = x
+         if self.normalize_before:
+             x = self.norm_mha(x)
+         x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+                                               att_cache)
+         x = residual + self.dropout(x_att)
+         if not self.normalize_before:
+             x = self.norm_mha(x)
+
+         # convolution module
+         # Fake new cnn cache here, and then change it in conv_module
+         new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+         if self.conv_module is not None:
+             residual = x
+             if self.normalize_before:
+                 x = self.norm_conv(x)
+             x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+             x = residual + self.dropout(x)
+
+             if not self.normalize_before:
+                 x = self.norm_conv(x)
+
+         # feed forward module
+         residual = x
+         if self.normalize_before:
+             x = self.norm_ff(x)
+
+         x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+         if not self.normalize_before:
+             x = self.norm_ff(x)
+
+         if self.conv_module is not None:
+             x = self.norm_final(x)
+
+         return x, mask, new_att_cache, new_cnn_cache
+
+
+ class UpsampleConformerEncoder(torch.nn.Module):
+     """
+     Args:
+         input_size (int): input dim
+         output_size (int): dimension of attention
+         attention_heads (int): the number of heads of multi head attention
+         linear_units (int): the number of hidden units of position-wise feed
+             forward
+         num_blocks (int): the number of encoder blocks
+         static_chunk_size (int): chunk size for static chunk training and
+             decoding
+         use_dynamic_chunk (bool): whether to use dynamic chunk size for
+             training or not, you can only use fixed chunk (chunk_size > 0)
+             or dynamic chunk size (use_dynamic_chunk = True)
+         use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+             dynamic chunk training
+         key_bias: whether to use bias in attention.linear_k, False for whisper models.
+     """
+
+     def __init__(
+         self,
+         input_size: int = 512,
+         output_size: int = 512,
+         attention_heads: int = 8,
+         linear_units: int = 2048,
+         num_blocks: int = 6,
+         static_chunk_size: int = 25,
+         use_dynamic_chunk: bool = False,
+         use_dynamic_left_chunk: bool = False,
+         key_bias: bool = True,
+     ):
+         super().__init__()
+         self._output_size = output_size
+
+         self.embed = LinearNoSubsampling(
+             input_size, output_size,
+             EspnetRelPositionalEncoding(output_size),
+         )
+
+         self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+         self.static_chunk_size = static_chunk_size
+         self.use_dynamic_chunk = use_dynamic_chunk
+         self.use_dynamic_left_chunk = use_dynamic_left_chunk
+         activation = torch.nn.SiLU()
+         # self-attention module definition
+         encoder_selfattn_layer_args = (
+             attention_heads,
+             output_size,
+             0.0,
+             key_bias,
+         )
+         # feed-forward module definition
+         positionwise_layer_args = (
+             output_size,
+             linear_units,
+             0.0,
+             activation,
+         )
+         # convolution module definition
+         self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+         self.encoders = torch.nn.ModuleList([
+             ConformerEncoderLayer(
+                 output_size,
+                 RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
+                 PositionwiseFeedForward(*positionwise_layer_args),
+             ) for _ in range(num_blocks)
+         ])
+         self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
+         self.up_embed = LinearNoSubsampling(
+             input_size, output_size,
+             EspnetRelPositionalEncoding(output_size),
+         )
+         self.up_encoders = torch.nn.ModuleList([
+             ConformerEncoderLayer(
+                 output_size,
+                 RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
+                 PositionwiseFeedForward(*positionwise_layer_args),
+             ) for _ in range(4)
+         ])
+
+     def output_size(self) -> int:
+         return self._output_size
+
+     def forward(
+         self,
+         xs: torch.Tensor,
+         xs_lens: torch.Tensor,
+         context: torch.Tensor = torch.zeros(0, 0, 0),
+         decoding_chunk_size: int = 0,
+         num_decoding_left_chunks: int = -1,
+         streaming: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Embed positions in tensor.
+
+         Args:
+             xs: padded input tensor (B, T, D)
+             xs_lens: input length (B)
+             decoding_chunk_size: decoding chunk size for dynamic chunk
+                 0: default for training, use random dynamic chunk.
+                 <0: for decoding, use full chunk.
+                 >0: for decoding, use fixed chunk size as set.
+             num_decoding_left_chunks: number of left chunks, this is for decoding,
+                 the chunk size is decoding_chunk_size.
+                 >=0: use num_decoding_left_chunks
+                 <0: use all left chunks
+         Returns:
+             encoder output tensor xs, and subsampled masks
+             xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+             masks: torch.Tensor batch padding mask after subsample
+                 (B, 1, T' ~= T/subsample_rate)
+         NOTE(xcsong):
+             We pass the `__call__` method of the modules instead of `forward` to the
+             checkpointing API because `__call__` attaches all the hooks of the module.
+             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+         """
+         T = xs.size(1)
+         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+         xs, pos_emb, masks = self.embed(xs, masks)
+         if context.size(1) != 0:
+             assert self.training is False, 'you have passed context, make sure that you are running inference mode'
+             context_masks = torch.ones(1, 1, context.size(1)).to(masks)
+             context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
+         mask_pad = masks  # (B, 1, T/subsample_rate)
+         chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
+         # lookahead + conformer encoder
+         xs = self.pre_lookahead_layer(xs, context=context)
+         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+         # upsample + conformer encoder
+         xs = xs.transpose(1, 2).contiguous()
+         xs, xs_lens = self.up_layer(xs, xs_lens)
+         xs = xs.transpose(1, 2).contiguous()
+         T = xs.size(1)
+         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+         xs, pos_emb, masks = self.up_embed(xs, masks)
+         mask_pad = masks  # (B, 1, T/subsample_rate)
+         chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
+         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+         xs = self.after_norm(xs)
+         # Here we assume the mask is not changed in encoder layers, so just
+         # return the masks before encoder layers, and the masks will be used
+         # for cross attention with decoder later
+         return xs, masks
+
+     def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                        pos_emb: torch.Tensor,
+                        mask_pad: torch.Tensor) -> torch.Tensor:
+         for layer in self.encoders:
+             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+         return xs
+
+     def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                           pos_emb: torch.Tensor,
+                           mask_pad: torch.Tensor) -> torch.Tensor:
+         for layer in self.up_encoders:
+             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+         return xs
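An end-to-end sketch for UpsampleConformerEncoder (not part of the diff, assuming it is importable under the path listed in this wheel): 6 conformer blocks with a 3-frame pre-lookahead, a 2x nearest-neighbour upsample, then 4 more conformer blocks, so the time axis doubles. The input here is random and only illustrates shapes.

import torch
from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import UpsampleConformerEncoder

encoder = UpsampleConformerEncoder(input_size=512, output_size=512).eval()
tokens = torch.randn(2, 120, 512)        # e.g. projected speech-token embeddings
token_lens = torch.tensor([120, 95])
with torch.no_grad():
    feats, masks = encoder(tokens, token_lens)
print(feats.shape)   # torch.Size([2, 240, 512]) -- time doubled by up_layer
print(masks.shape)   # torch.Size([2, 1, 240])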