sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/phi4mm_utils.py (new file, +1917 lines)
@@ -0,0 +1,1917 @@
+ # Copyright 2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ #!/usr/bin/env python3
+ import math
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class BlockBase(nn.Module):
+     """Block abstract module"""
+
+     def __init__(self, input_size, output_size):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+
+
+ def get_activation(name="relu"):
+     """Select an activation function by name
+
+     Args:
+         name: str
+             activation function name,
+             one of ["relu", "gelu", "swish", "sigmoid"],
+             default "relu".
+     """
+     name = name.lower()
+     if name == "relu":
+         return nn.ReLU(inplace=True)
+     if name == "gelu":
+         return nn.GELU()
+     if name == "swish":
+         return Swish()
+     if name == "sigmoid":
+         return torch.nn.Sigmoid()
+     return nn.Identity()
+
+
+ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+     """
+     The function is very important for Transformer Transducer Streaming mode
+     Args:
+         xs_len (int): sequence length
+         chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
+             It also supports adaptive chunk size [0,10,15,45]
+         left_window (int): how many left chunks can be seen
+         right_window (int): how many right chunks can be seen. It is used for
+             chunk overlap model.
+     Returns:
+         mask (torch.Tensor): a mask tensor for streaming model
+             Torch 1.0.1
+             tensor([[1., 1., 0., 0.],
+                     [0., 1., 1., 0.],
+                     [0., 0., 1., 1.]])
+             Torch 1.4.1
+             tensor([[True., True., False., False.],
+                     [False., True., True., False.],
+                     [False., False., True., True.]])
+     """
+     chunk_start_idx = torch.Tensor(
+         chunk_start_idx
+     ).long()  # first idx of each chunk, such as [0,18,36,48].
+     start_pad = torch.nn.functional.pad(
+         chunk_start_idx, (1, 0)
+     )  # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
+     end_pad = torch.nn.functional.pad(
+         chunk_start_idx, (0, 1), value=x_len
+     )  # append x_len to the end, so it becomes [0,18,36,48, x_len]
+     seq_range = torch.arange(0, x_len).unsqueeze(-1)  # seq_range size: [x_len, 1]
+     idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[
+         :, 1
+     ]  # idx size: [x_len]
+     # boundary = end_pad[idx]  # boundary size: [x_len]
+     seq_range_expand = (
+         torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
+     )  # seq_range_expand size [x_len, x_len]
+     idx_left = idx - left_window
+     idx_left[idx_left < 0] = 0
+     boundary_left = start_pad[idx_left]
+     mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+     idx_right = idx + right_window
+     idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+     boundary_right = end_pad[idx_right]
+     mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+     return mask_left & mask_right
+
+
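An illustrative sketch of the streaming mask this helper produces (not part of the diff; it assumes the new module is importable as sglang.srt.models.phi4mm_utils, the file this hunk adds):

    import torch
    from sglang.srt.models.phi4mm_utils import adaptive_enc_mask

    # Two chunks of width 2 over 4 frames, no left/right context:
    # each frame may only attend within its own chunk.
    mask = adaptive_enc_mask(x_len=4, chunk_start_idx=[0, 2])
    # tensor([[ True,  True, False, False],
    #         [ True,  True, False, False],
    #         [False, False,  True,  True],
    #         [False, False,  True,  True]])

Passing left_window=1 would additionally expose the previous chunk to every frame.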
+ class Swish(nn.Module):
+     """Implement Swish activation module.
+     From https://arxiv.org/pdf/2005.03191.pdf
+
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.act_fn = nn.Sigmoid()
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Apply Swish function
+
+         Args:
+             x: torch.Tensor
+                 Input.
+         """
+         return x * self.act_fn(x)
+
+
+ class GLU(nn.Module):
+     """Implement Gated Linear Unit (GLU) module"""
+
+     def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
+         super().__init__()
+         self.dim = dim
+         self.act_name = act_name.lower()
+
+         if self.act_name == "relu":
+             self.act_fn = nn.ReLU(inplace=True)
+         elif self.act_name == "gelu":
+             self.act_fn = nn.GELU()
+         elif self.act_name == "swish":
+             self.act_fn = Swish()
+         elif self.act_name == "sigmoid":
+             self.act_fn = nn.Sigmoid()
+         else:
+             self.act_fn = nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         """GLU forward
+         Apply Swish function on the first half of input matrices
+         with sigmoid of the second half.
+
+         Args:
+             x: torch.Tensor
+                 Input.
+
+         """
+         half_x, gate = x.chunk(2, dim=self.dim)
+         return half_x * self.act_fn(gate)
+
+
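A quick shape check for GLU (same import assumption as above): the input is split in half along dim, and the first half is gated by the activated second half, so the feature dimension halves.

    import torch
    from sglang.srt.models.phi4mm_utils import GLU

    glu = GLU(dim=-1, act_name="swish")
    x = torch.randn(2, 8, 64)   # (batch, time, 2 * features)
    print(glu(x).shape)         # torch.Size([2, 8, 32])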
+ # TODO: Abdel, this can be improved using GLU module
+ class GLUPointWiseConv(nn.Module):
+     """GLUPointWiseConv module
+     used for conformer architecture,
+     for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         output_dim: int
+             output channel size.
+         kernel_size: int
+             kernel size
+         glu_type: str, optional
+             activation function one of
+             ["sigmoid", "relu", "gelu"]
+             default "sigmoid".
+         bias_in_glu: bool, optional
+             use addtive bias in glu
+         causal: bool, optional
+             if set to True, padding is set to the half of
+             kernel size, ie, convolution can't see future frames.
+             default False.
+
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         output_dim,
+         kernel_size,
+         glu_type="sigmoid",
+         bias_in_glu=True,
+         causal=False,
+     ):
+         super().__init__()
+
+         self.glu_type = glu_type
+         self.output_dim = output_dim
+         self.bias_in_glu = bias_in_glu
+         if causal:
+             self.ext_pw_conv_1d = nn.Conv1d(
+                 input_dim,
+                 output_dim * 2,
+                 kernel_size,
+                 1,
+                 padding=(kernel_size - 1),
+             )
+         else:
+             self.ext_pw_conv_1d = nn.Conv1d(
+                 input_dim,
+                 output_dim * 2,
+                 kernel_size,
+                 1,
+                 padding=(kernel_size - 1) // 2,
+             )
+
+         if glu_type == "sigmoid":
+             self.glu_act = nn.Sigmoid()
+         elif glu_type == "relu":
+             self.glu_act = nn.ReLU()
+         elif glu_type == "gelu":
+             self.glu_act = nn.GELU()
+         elif glu_type == "swish":
+             self.glu_act = Swish()
+         else:
+             raise ValueError(f"Unsupported activation type {self.glu_act}")
+
+         if bias_in_glu:
+             self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
+             self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))
+
+     def forward(self, x):
+         """
+         Args:
+             x: torch.Tensor
+                 input tensor
+         """
+         # to be consistent with GLULinear, we assume the input always has the
+         # #channel (#dim) in the last dimension of the tensor, so need to
+         # switch the dimension first for 1D-Conv case
+         x = x.permute([0, 2, 1])
+         x = self.ext_pw_conv_1d(x)
+         if self.glu_type == "bilinear":
+             if self.bias_in_glu:
+                 x = (x[:, 0 : self.output_dim, :] + self.b1) * (
+                     x[:, self.output_dim : self.output_dim * 2, :] + self.b2
+                 )
+             else:
+                 x = (x[:, 0 : self.output_dim, :]) * (
+                     x[:, self.output_dim : self.output_dim * 2, :]
+                 )
+         else:
+             if self.bias_in_glu:
+                 x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act(
+                     x[:, self.output_dim : self.output_dim * 2, :] + self.b2
+                 )
+             else:
+                 x = (x[:, 0 : self.output_dim, :]) * self.glu_act(
+                     x[:, self.output_dim : self.output_dim * 2, :]
+                 )
+
+         x = x.permute([0, 2, 1])
+         return x
+
+
+ class DepthWiseSeperableConv1d(nn.Module):
+     """DepthWiseSeperableConv1d module used in Convnet module
+     for the conformer, for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         depthwise_seperable_out_channel: int
+             if set different to 0, the number of
+             depthwise_seperable_out_channel will be used as a channel_out
+             of the second conv1d layer.
+             otherwise, it equal to 0, the second conv1d layer is skipped.
+         kernel_size: int
+             kernel_size
+         depthwise_multiplier: int
+             number of input_dim channels duplication. this value
+             will be used to compute the hidden channels of the Conv1D.
+         padding: int, optional
+             padding for the conv1d,
+             default: 0.
+
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         depthwise_seperable_out_channel,
+         kernel_size,
+         depthwise_multiplier,
+         padding=0,
+     ):
+         super().__init__()
+
+         self.dw_conv = nn.Conv1d(
+             input_dim,
+             input_dim * depthwise_multiplier,
+             kernel_size,
+             1,
+             padding=padding,
+             groups=input_dim,
+         )
+
+         if depthwise_seperable_out_channel != 0:
+             self.pw_conv = nn.Conv1d(
+                 input_dim * depthwise_multiplier,
+                 depthwise_seperable_out_channel,
+                 1,
+                 1,
+                 0,
+             )
+         else:
+             self.pw_conv = nn.Identity()
+         self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+
+     def forward(self, x):
+         """
+
+         Args:
+             x: torch.Tensor
+                 input tensor
+         """
+         x = self.dw_conv(x)
+         if self.depthwise_seperable_out_channel != 0:
+             x = self.pw_conv(x)
+         return x
+
+
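The depthwise/pointwise factorization above replaces one dense Conv1d with a per-channel k-tap convolution followed by a 1x1 channel mixer. A hypothetical shape walk-through (illustrative, same import assumption as above):

    import torch
    from sglang.srt.models.phi4mm_utils import DepthWiseSeperableConv1d

    conv = DepthWiseSeperableConv1d(
        input_dim=64,
        depthwise_seperable_out_channel=128,
        kernel_size=3,
        depthwise_multiplier=1,
        padding=1,
    )
    x = torch.randn(2, 64, 100)  # (batch, channels, time)
    print(conv(x).shape)         # torch.Size([2, 128, 100])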
+ class ConvModule(nn.Module):
+     """ConvModule Module for the conformer block.
+     for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         ext_pw_out_channel: int
+             if > 0, ext_pw_out_channel is a dim channel size
+             for the last pointwise conv after swish activation.
+         depthwise_seperable_out_channel: int
+             if set different to 0, the number of
+             depthwise_seperable_out_channel
+             will be used as a channel_out of the second conv1d layer.
+             otherwise, it equal to 0, the second conv1d layer is skipped.
+         ext_pw_kernel_size: int
+             kernel size of the conv pointwise of the conformer.
+         kernel_size: int
+             kernel size.
+         depthwise_multiplier: int
+             number of input_dim channels duplication. this value
+             will be used to compute the hidden channels of the Conv1D.
+         dropout_rate: float
+             dropout rate.
+         causal: bool, optional
+             if set to True, convolution have no access
+             to future frames. default False.
+         batch_norm: bool, optional
+             if set to True, apply batchnorm before activation.
+             default False
+         chunk_se: int, optional
+             0 for offline SE.
+             1 for streaming SE, where mean is computed
+             by accumulated history until current chunk_se.
+             2 for streaming SE, where mean is computed
+             by only the current chunk.
+         chunk_size: int, optional
+             chunk size for cnn. default 18
+         activation: str, optional
+             activation function used in ConvModule,
+             default: "relu".
+         glu_type: str, optional
+             activation function used for the glu,
+             default: "sigmoid".
+         bias_in_glu: bool, optional
+             if set to True, use additive bias in the weight module
+             before GLU.
+         linear_glu_in_convm: bool, optional
+             if set to True, use GLULinear module,
+             otherwise, used GLUPointWiseConv module.
+             default to False.
+         export: bool, optional,
+             if set to True, padding is equal to 0. This is for inference,
+             or onnx export. Typically this is set by the export program or
+             the decoder program, and it isn't present in your config file.
+             default False
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         ext_pw_out_channel,
+         depthwise_seperable_out_channel,
+         ext_pw_kernel_size,
+         kernel_size,
+         depthwise_multiplier,
+         dropout_rate,
+         causal=False,
+         batch_norm=False,
+         chunk_se=0,
+         chunk_size=18,
+         activation="relu",
+         glu_type="sigmoid",
+         bias_in_glu=True,
+         linear_glu_in_convm=False,
+         export=False,
+     ):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(input_dim)
+         self.input_dim = input_dim
+         self.ext_pw_out_channel = ext_pw_out_channel
+         self.ext_pw_kernel_size = ext_pw_kernel_size
+         self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+         self.glu_type = glu_type
+         self.bias_in_glu = bias_in_glu
+         self.linear_glu_in_convm = linear_glu_in_convm
+         self.causal = causal
+
+         self._add_ext_pw_layer()
+
+         self.batch_norm = batch_norm
+         self.kernel_size = kernel_size
+
+         if batch_norm:
+             self.bn_layer = nn.BatchNorm1d(input_dim)
+
+         self.act = get_activation(activation)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.export = export
+
+         if causal:
+             padding = 0 if export else kernel_size - 1
+         else:
+             padding = (kernel_size - 1) // 2
+
+         self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
+             input_dim,
+             depthwise_seperable_out_channel,
+             kernel_size,
+             depthwise_multiplier,
+             padding=padding,
+         )
+
+         if depthwise_seperable_out_channel != 0:
+             if input_dim != depthwise_seperable_out_channel:
+                 self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
+         else:
+             if depthwise_multiplier != 1:
+                 self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)
+
+     def _add_ext_pw_layer(self):
+         """
+         This function is an extension of __init__ function
+         and dedicated to the convolution module creation
+         of the conformer.
+         """
+         self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
+             nn.Identity()
+         )  # jit hacks.
+         self.squeeze_excitation = nn.Identity()  # jit.
+         self.apply_ln1 = self.fix_len1 = False  # jit.
+
+         if self.ext_pw_out_channel != 0:
+             if self.causal:
+                 self.ext_pw_conv_1d = nn.Conv1d(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     1,
+                     padding=(self.ext_pw_kernel_size - 1),
+                 )
+                 if self.ext_pw_kernel_size > 1:
+                     self.fix_len1 = True
+                 else:
+                     self.fix_len1 = False
+             else:
+                 self.ext_pw_conv_1d = nn.Conv1d(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     1,
+                     padding=(self.ext_pw_kernel_size - 1) // 2,
+                 )
+                 self.fix_len1 = False
+
+             if self.linear_glu_in_convm:
+                 self.glu = GLULinear(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.glu_type,
+                     self.bias_in_glu,
+                 )
+             else:
+                 self.glu = GLUPointWiseConv(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     self.glu_type,
+                     self.bias_in_glu,
+                     self.causal,
+                 )
+
+             if self.input_dim != self.ext_pw_out_channel:
+                 self.apply_ln1 = True
+                 self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
+             else:
+                 self.apply_ln1 = False
+         else:
+             self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
+             self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))
+
+     def forward(self, x):
+         """ConvModule Forward.
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+         """
+         x = self.layer_norm(x)
+
+         if self.ext_pw_out_channel != 0:
+             x = self.glu(x)
+             if self.causal and self.ext_pw_kernel_size > 1:
+                 x = x[:, : -(self.ext_pw_kernel_size - 1), :]
+             if self.apply_ln1:
+                 x = self.ln1(x)
+         else:
+             x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
+             x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
+             x = x_0 + x_1
+
+         x = x.permute([0, 2, 1])
+
+         x = self.dw_sep_conv_1d(x)
+         if self.causal and self.kernel_size > 1:
+             x = x[:, :, : -(self.kernel_size - 1)]
+         if hasattr(self, "ln2"):
+             x = x.permute([0, 2, 1])
+             x = self.ln2(x)
+             x = x.permute([0, 2, 1])
+         if self.batch_norm:
+             x = self.bn_layer(x)
+         x = self.act(x)
+
+         if self.ext_pw_out_channel != 0:
+             x = self.ext_pw_conv_1d(x)
+             if self.fix_len1:
+                 x = x[:, :, : -(self.ext_pw_kernel_size - 1)]
+
+             if self.apply_ln1:
+                 x = x.permute([0, 2, 1])
+                 x = self.ln1(x)
+                 x = x.permute([0, 2, 1])
+
+             x = x.permute([0, 2, 1])
+         else:
+             x = x.unsqueeze(1).permute([0, 1, 3, 2])
+             x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
+             x = x.squeeze(1)
+
+         x = self.dropout(x)
+         return x
+
+
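End to end, ConvModule is shape-preserving over (batch, time, input_dim) in the common non-causal configuration: LayerNorm, GLU pointwise conv, depthwise-separable conv, activation, final pointwise conv, dropout. A sketch under the same import assumption:

    import torch
    from sglang.srt.models.phi4mm_utils import ConvModule

    conv_block = ConvModule(
        input_dim=64,
        ext_pw_out_channel=64,
        depthwise_seperable_out_channel=64,
        ext_pw_kernel_size=1,
        kernel_size=3,
        depthwise_multiplier=1,
        dropout_rate=0.0,
    )
    x = torch.randn(2, 50, 64)   # (batch, time, features)
    print(conv_block(x).shape)   # torch.Size([2, 50, 64])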
+ class GLULinear(nn.Module):
+     """Linear + GLU module
+
+     Args:
+         input_dim: int
+             input size
+         output_dim: int
+             output size.
+         glu_type:
+             activation function name used in glu module.
+             default "sigmoid" (swish function).
+         bias_in_glu: bool, optional
+             If True, the addtive bias is added. Default False.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         output_dim,
+         glu_type="sigmoid",
+         bias_in_glu=True,
+     ):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
+         self.glu_act = GLU(-1, glu_type)
+
+     def forward(self, x):
+         """GLULinear forward
+
+         Args:
+             x: torch.Tensor
+                 inpute tensor.
+         """
+         x = self.linear(x)
+         return self.glu_act(x)
+
+
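GLULinear projects to 2 * output_dim and lets the GLU gate halve it back, so the visible output size is output_dim (same import assumption):

    import torch
    from sglang.srt.models.phi4mm_utils import GLULinear

    proj = GLULinear(input_dim=64, output_dim=256, glu_type="swish")
    x = torch.randn(2, 50, 64)
    print(proj(x).shape)         # torch.Size([2, 50, 256])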
+ class FeedForward(nn.Module):
+     """FeedForward Module.
+     For more details see Conformer paper:
+     https://arxiv.org/pdf/2005.08100.pdf
+
+     Args:
+         d_model: int
+             input size.
+         d_inner: int
+             output size.
+         dropout_rate: float,
+             dropout rate.
+         activation: str,
+             activation function name,
+             one of ["relu", "swish", "sigmoid"],
+             sigmoid activation is only used with "glu_in_fnn=True",
+             default "sigmoid".
+         bias_in_glu: bool, optional
+     """
+
+     def __init__(
+         self,
+         d_model,
+         d_inner,
+         dropout_rate,
+         activation="sigmoid",
+         bias_in_glu=True,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.d_inner = d_inner
+
+         self.layer_norm = nn.LayerNorm(d_model)
+         module = GLULinear(d_model, d_inner, activation, bias_in_glu)
+         self.net = nn.Sequential(
+             module,
+             nn.Dropout(dropout_rate),
+             nn.Linear(d_inner, d_model),
+             nn.Dropout(dropout_rate),
+         )
+
+     def forward(self, x):
+         """FeedForward forward function.
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+         """
+         out = self.net(self.layer_norm(x))
+
+         return out
+
+
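The resulting feed-forward block maps d_model -> 2*d_inner (linear) -> d_inner (GLU gate) -> d_model, so it is residual-friendly (same import assumption):

    import torch
    from sglang.srt.models.phi4mm_utils import FeedForward

    ffn = FeedForward(d_model=64, d_inner=256, dropout_rate=0.0)
    x = torch.randn(2, 50, 64)
    print(ffn(x).shape)          # torch.Size([2, 50, 64])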
+ #### positional encoding starts here
+ def _pre_hook(
+     state_dict,
+     prefix,
+     local_metadata,
+     strict,
+     missing_keys,
+     unexpected_keys,
+     error_msgs,
+ ):
+     """Perform pre-hook in load_state_dict for backward compatibility.
+
+     Note:
+         We saved self.pe until v.0.5.2 but we have omitted it later.
+         Therefore, we remove the item "pe" from `state_dict` for backward
+         compatibility.
+
+     """
+     k = prefix + "pe"
+     if k in state_dict:
+         state_dict.pop(k)
+
+
+ class T5RelativeAttentionLogitBias(nn.Module):
+     """
+     This module implements the relative position bias described in Section
+     2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
+
+     The Huggingface implementation is used as a reference
+     https://github.com/huggingface/transformers/blob/v4.30.0/src/
+     transformers/models/t5/modeling_t5.py#L435
+
+     Modifies attention as Q*K^T + B, where B is a learned scalar bias based
+     on relative position of the query and key. It is HxNxN, where H is the
+     number of heads, N is the sequence length.
+
+     I've made these modifications to the original T5 bias:
+     - Skipping of the bucketing step. Original T5 bias converted rel
+       position distances into logarithmically increasing buckets. This is
+       supposed to help with length generalization.
+     - I just directly use rel position index as bias values, as we don't
+       need length generalization (40s max is good enough for ASR encoder),
+       and it keeps ONNX export simple.
+     - I've also extended it so that biases can be asymmetric, the default
+       implementation treats L->R and R->L the same. Asymmetric was found to
+       yield better results in my experiments.
+
+     Args:
+         num_heads: int
+             Number of attention heads
+         num_buckets: int
+             Number of buckets to use for relative attention bias. This is the
+             size of the learnable bias parameter. Bucketing is not yet
+             supported, so this defaults to -1 which means no bucketing is
+             used (max_distance determines size of bias param).
+         max_distance: int
+             Maximum distance to use for relative attention bias. With
+             num_buckets=-1, this directly controls the max size of the bias
+             parameter. When num_buckets > 0 is supported, this will control
+             the maximum distance for logarithmic bucketing after which all
+             positions are in the same bucket.
+         symmetric: bool
+             Whether to use symmetric or asymmetric biases. symmetric=False uses
+             2x number of bias params to distinguish L->R from R->L. This was
+             found to be better for the encoder.
+     """
+
+     def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.symmetric = symmetric
+         self._skip_bucketing = self.num_buckets < 0
+         if self._skip_bucketing:
+             self.num_buckets = max_distance
+         else:
+             raise NotImplementedError(
+                 "T5 attention bias with bucketed positions is not yet tested"
+             )
+         if not self.symmetric:
+             self.num_buckets *= 2
+         self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)
+
+     def forward(self, x):
+         # instantiate bias compatible with shape of x
+         maxpos = x.size(1)
+         context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
+             :, None
+         ]
+         memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
+             None, :
+         ]
+         relative_position = memory_position - context_position
+         # clipping to a maximum distance using ops that play well with ONNX
+         # export
+         relative_position = relative_position.masked_fill(
+             relative_position < -self.max_distance, -self.max_distance
+         )
+         relative_position = relative_position.masked_fill(
+             relative_position > self.max_distance - 1, self.max_distance - 1
+         )
+
+         # mapping from relative position to index in the bias parameter
+         if self._skip_bucketing:
+             bias_idx = relative_position
+         else:
+             bias_idx = self._bucket_relative_position(relative_position)
+         if self.symmetric:
+             bias_idx = bias_idx.abs()
+         else:
+             bias_idx += self.num_buckets // 2
+
+         t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
+         t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]
+
+         return t5_rel_att_bias
+
+     def _bucket_relative_position(self, relative_position):
+         # this is a placeholder (isn't tested, likely buggy) using HuggingFace
+         # implem as a reference this also needs to be extended to support
+         # asymmetric +/- ve positions
+         relative_buckets = 0
+         if not self.causal:
+             self.num_buckets //= 2
+             relative_buckets += (relative_position > 0).to(
+                 torch.long
+             ) * self.num_buckets
+             relative_position = torch.abs(relative_position)
+         else:
+             relative_position = -torch.min(
+                 relative_position, torch.zeros_like(relative_position)
+             )
+         # now relative_position is in the range [0, inf)
+
+         # half of the buckets are for exact increments in positions
+         max_exact = self.num_buckets // 2
+         is_small = relative_position < max_exact
+
+         # The other half of the buckets are for logarithmically bigger bins in
+         # positions up to max_distance
+         relative_position_if_large = max_exact + (
+             torch.log(relative_position.float() / max_exact)
+             / math.log(self.max_distance / max_exact)
+             * (self.num_buckets - max_exact)
+         ).to(torch.long)
+         relative_position_if_large = torch.min(
+             relative_position_if_large,
+             torch.full_like(relative_position_if_large, self.num_buckets - 1),
+         )
+
+         relative_buckets += torch.where(
+             is_small, relative_position, relative_position_if_large
+         )
+         return relative_buckets
+
+
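With the default num_buckets=-1 the bias table has 2 * max_distance rows (asymmetric), indexed by the clipped relative offset plus max_distance. An illustrative shape check (same import assumption):

    import torch
    from sglang.srt.models.phi4mm_utils import T5RelativeAttentionLogitBias

    bias = T5RelativeAttentionLogitBias(num_heads=8, max_distance=100)
    x = torch.randn(2, 50, 64)   # only x.size(1) and x.device are used
    print(bias(x).shape)         # torch.Size([1, 8, 50, 50]), added to Q @ K^T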
+ class AbsolutePositionalEncoding(nn.Module):
+     """Absolute Positional encoding module.
+     This module implement Absolute sinusoidal positional encoding
+     from: https://arxiv.org/pdf/1706.03762.pdf
+
+     Args:
+         d_model: int
+             Input embedding size.
+         dropout_rate: float
+             dropout rate
+         max_len: int, optional
+             Maximum input length sequence, Default 5000
+
+     """
+
+     def __init__(self, d_model, dropout_rate, max_len=5000):
+         """Construct an PositionalEncoding object."""
+         super().__init__()
+         self.d_model = d_model
+         self.xscale = math.sqrt(self.d_model)
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+         self._register_load_state_dict_pre_hook(_pre_hook)
+
+     def extend_pe(self, x):
+         """Reset the positional encodings.
+
+         Args:
+             x: torch.Tensor
+         """
+         if self.pe is not None and self.pe.size(1) >= x.size(1):
+             if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+             return
+         pe = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32)
+             * -(math.log(10000.0) / self.d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor):
+         """Add positional encoding.
+
+         Args:
+             x: torch.Tensor
+                 Input tensor. shape is (batch, time, ...)
+
+         Returns:
+             torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+
+         """
+         self.extend_pe(x)
+         x = x * self.xscale + self.pe[:, : x.size(1)]
+         return self.dropout(x)
+
+
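The table is the standard sinusoidal encoding, PE[pos, 2i] = sin(pos / 10000^(2i/d_model)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)), with the input scaled by sqrt(d_model) before the addition. Sketch (same import assumption):

    import torch
    from sglang.srt.models.phi4mm_utils import AbsolutePositionalEncoding

    pos_enc = AbsolutePositionalEncoding(d_model=64, dropout_rate=0.0)
    x = torch.randn(2, 50, 64)
    print(pos_enc(x).shape)      # torch.Size([2, 50, 64])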
+ #### forward embedding layers starts here
+ class MeanVarianceNormLayer(nn.Module):
+     """Mean/variance normalization layer.
+
+     Will subtract mean and multiply input by inverted standard deviation.
+     Typically used as a very first layer in a model.
+
+     Args:
+         input_size: int
+             layer input size.
+     """
+
+     def __init__(self, input_size):
+         super().__init__()
+         self.input_size = input_size
+         self.global_mean = nn.Parameter(torch.zeros(input_size))
+         self.global_invstd = nn.Parameter(torch.ones(input_size))
+
+     def forward(self, input_: Tensor) -> Tensor:
+         """MeanVarianceNormLayer Forward
+
+         Args:
+             input_: torch.Tensor
+                 input tensor.
+         """
+         return (input_ - self.global_mean) * self.global_invstd
+
+
+ class CausalConv1D(nn.Conv1d):
+     """
+     A causal version of nn.Conv1d where each step would have limited access to
+     locations on its right or left
+     All arguments are the same as nn.Conv1d except padding.
+
+     If padding is set None, then paddings are set automatically to make it a
+     causal convolution where each location would not see any steps on its right.
+
+     If padding is set as a list (size of 2), then padding[0] would be used as
+     left padding and padding[1] as right padding.
+     It would make it possible to control the number of steps to be accessible
+     on the right and left.
+     This mode is not supported when stride > 1. padding[0]+padding[1] should
+     be equal to (kernel_size - 1).
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: Union[str, int] = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",
+         device=None,
+         dtype=None,
+     ) -> None:
+         self.cache_drop_size = None
+         if padding is None:
+             self._left_padding = kernel_size - 1
+             self._right_padding = stride - 1
+         else:
+             if stride != 1 and padding != kernel_size - 1:
+                 raise ValueError("No striding allowed for non-symmetric convolutions!")
+             if isinstance(padding, int):
+                 self._left_padding = padding
+                 self._right_padding = padding
+             elif (
+                 isinstance(padding, list)
+                 and len(padding) == 2
+                 and padding[0] + padding[1] == kernel_size - 1
+             ):
+                 self._left_padding = padding[0]
+                 self._right_padding = padding[1]
+             else:
+                 raise ValueError(f"Invalid padding param: {padding}!")
+
+         self._max_cache_len = self._left_padding
+
+         super().__init__(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=0,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             padding_mode=padding_mode,
+             device=device,
+             dtype=dtype,
+         )
+
+     def update_cache(self, x, cache=None):
+         if cache is None:
+             new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
+             next_cache = cache
+         else:
+             new_x = F.pad(x, pad=(0, self._right_padding))
+             new_x = torch.cat([cache, new_x], dim=-1)
+             if self.cache_drop_size > 0:
+                 next_cache = new_x[:, :, : -self.cache_drop_size]
+             else:
+                 next_cache = new_x
+             next_cache = next_cache[:, :, -cache.size(-1) :]
+         return new_x, next_cache
+
+     def forward(self, x, cache=None):
+         x, cache = self.update_cache(x, cache=cache)
+         x = super().forward(x)
+         if cache is None:
+             return x
+         else:
+             return x, cache
+
+
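With padding=None, CausalConv1D left-pads by kernel_size - 1 (and right-pads by stride - 1), so output frame t depends only on inputs up to t and the length is preserved at stride 1. Sketch (same import assumption):

    import torch
    from sglang.srt.models.phi4mm_utils import CausalConv1D

    conv = CausalConv1D(in_channels=8, out_channels=8, kernel_size=3, padding=None)
    x = torch.randn(2, 8, 100)   # (batch, channels, time)
    print(conv(x).shape)         # torch.Size([2, 8, 100])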
+ class CausalConv2D(nn.Conv2d):
+     """
+     A causal version of nn.Conv2d where each location in the 2D matrix would
+     have no access to locations on its right or down
+     All arguments are the same as nn.Conv2d except padding which should be
+     set as None
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: Union[str, int] = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",
+         device=None,
+         dtype=None,
+     ) -> None:
+         if padding is not None:
+             raise ValueError("Argument padding should be set to None for CausalConv2D.")
+         self._left_padding = kernel_size - 1
+         self._right_padding = stride - 1
+
+         padding = 0
+         super().__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding,
+             dilation,
+             groups,
+             bias,
+             padding_mode,
+             device,
+             dtype,
+         )
+
+     def forward(
+         self,
+         x,
+     ):
+         x = F.pad(
+             x,
+             pad=(self._left_padding, self._right_padding, 0, 0),
+         )
+         x = super().forward(x)
+         return x
+
+
1046
+ class NemoConvSubsampling(torch.nn.Module):
1047
+ """Convlutional subsampling module, taken from NeMo ASR
1048
+ (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
1049
+ 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)
1050
+
1051
+ Striding Subsampling: "Speech-Transformer: A No-Recurrence
1052
+ Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
1053
+ et al. (https://ieeexplore.ieee.org/document/8462506)
1054
+
1055
+
1056
+ Compared with the EncoderConv2D (`input_layer: custom`), this is a
1057
+ much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
1058
+ Moreover, depthwise convolutions are used to reduce FLOPs, but the first
1059
+ layer is kept as a regular convolution so as not to degrade accuracy.
1060
+
1061
+ `Striding` and `dw_striding` are the same except that the latter uses
1062
+ depthwise convolutions after the first layer, whereas the former does not.
1063
+
1064
+ Args:
1065
+ subsampling_factor (int): Time reduction factor
1066
+ feat_in (int): size of the input features
1067
+ feat_out (int): size of the output features
1068
+ subsampling (str): The subsampling technique, choose from
1069
+ {"striding", "dw-striding", "striding_conv1d",
1070
+ "dw_striding_conv1d"}
1071
+ conv_channels (int): Number of channels for the convolution layers,
1072
+ default is 256.
1073
+ subsampling_conv_chunking_factor (int): Input chunking factor which
1074
+ can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1
1075
+ activation (Module): activation function, default is nn.ReLU()
1076
+ is_causal (bool): whether to use causal Conv1/2D, where each step will
1077
+ have limited access to locations on its right or left
1078
+ """
1079
+
1080
+ def __init__(
1081
+ self,
1082
+ feat_in,
1083
+ feat_out,
1084
+ subsampling_factor=4,
1085
+ subsampling="dw_striding",
1086
+ conv_channels=256,
1087
+ subsampling_conv_chunking_factor=1,
1088
+ activation=nn.ReLU(), # noqa: B008
1089
+ is_causal=False,
1090
+ ):
1091
+ super().__init__()
1092
+ self._subsampling = subsampling
1093
+ self._conv_channels = conv_channels
1094
+ self._feat_in = feat_in
1095
+ self._feat_out = feat_out
1096
+
1097
+ if subsampling_factor % 2 != 0:
1098
+ raise ValueError("Sampling factor should be a multiply of 2!")
1099
+ self._sampling_num = int(math.log(subsampling_factor, 2))
1100
+ self.subsampling_factor = subsampling_factor
1101
+ self.is_causal = is_causal
1102
+ self.subsampling_causal_cond = subsampling in (
1103
+ "dw_striding",
1104
+ "striding",
1105
+ "striding_conv1d",
1106
+ )
1107
+
1108
+ if (
1109
+ subsampling_conv_chunking_factor != -1
1110
+ and subsampling_conv_chunking_factor != 1
1111
+ and subsampling_conv_chunking_factor % 2 != 0
1112
+ ):
1113
+ raise ValueError(
1114
+ "subsampling_conv_chunking_factor should be -1, 1, or a " "power of 2"
1115
+ )
1116
+ self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
1117
+
1118
+ in_channels = 1
1119
+ layers = []
1120
+
1121
+ if subsampling == "dw_striding":
1122
+ self._stride = 2
1123
+ self._kernel_size = 3
1124
+ self._ceil_mode = False
1125
+
1126
+ if self.is_causal:
1127
+ self._left_padding = self._kernel_size - 1
1128
+ self._right_padding = self._stride - 1
1129
+ self._max_cache_len = subsampling_factor + 1
1130
+ else:
1131
+ self._left_padding = (self._kernel_size - 1) // 2
1132
+ self._right_padding = (self._kernel_size - 1) // 2
1133
+ self._max_cache_len = 0
1134
+
1135
+ # Layer 1
1136
+ if self.is_causal:
1137
+ layers.append(
1138
+ CausalConv2D(
1139
+ in_channels=in_channels,
1140
+ out_channels=conv_channels,
1141
+ kernel_size=self._kernel_size,
1142
+ stride=self._stride,
1143
+ padding=None,
1144
+ )
1145
+ )
1146
+ else:
1147
+ layers.append(
1148
+ torch.nn.Conv2d(
1149
+ in_channels=in_channels,
1150
+ out_channels=conv_channels,
1151
+ kernel_size=self._kernel_size,
1152
+ stride=self._stride,
1153
+ padding=self._left_padding,
1154
+ )
1155
+ )
1156
+ in_channels = conv_channels
1157
+ layers.append(activation)
1158
+
1159
+ for i in range(self._sampling_num - 1):
1160
+ if self.is_causal:
1161
+ layers.append(
1162
+ CausalConv2D(
1163
+ in_channels=in_channels,
1164
+ out_channels=in_channels,
1165
+ kernel_size=self._kernel_size,
1166
+ stride=self._stride,
1167
+ padding=None,
1168
+ groups=in_channels,
1169
+ )
1170
+ )
1171
+ else:
1172
+ layers.append(
1173
+ torch.nn.Conv2d(
1174
+ in_channels=in_channels,
1175
+ out_channels=in_channels,
1176
+ kernel_size=self._kernel_size,
1177
+ stride=self._stride,
1178
+ padding=self._left_padding,
1179
+ groups=in_channels,
1180
+ )
1181
+ )
1182
+
1183
+ layers.append(
1184
+ torch.nn.Conv2d(
1185
+ in_channels=in_channels,
1186
+ out_channels=conv_channels,
1187
+ kernel_size=1,
1188
+ stride=1,
1189
+ padding=0,
1190
+ groups=1,
1191
+ )
1192
+ )
1193
+ layers.append(activation)
1194
+ in_channels = conv_channels
1195
+
1196
+ elif subsampling == "striding":
1197
+ self._stride = 2
1198
+ self._kernel_size = 3
1199
+ self._ceil_mode = False
1200
+
1201
+ if self.is_causal:
1202
+ self._left_padding = self._kernel_size - 1
1203
+ self._right_padding = self._stride - 1
1204
+ self._max_cache_len = subsampling_factor + 1
1205
+ else:
1206
+ self._left_padding = (self._kernel_size - 1) // 2
1207
+ self._right_padding = (self._kernel_size - 1) // 2
1208
+ self._max_cache_len = 0
1209
+
1210
+ for i in range(self._sampling_num):
1211
+ if self.is_causal:
1212
+ layers.append(
1213
+ CausalConv2D(
1214
+ in_channels=in_channels,
1215
+ out_channels=conv_channels,
1216
+ kernel_size=self._kernel_size,
1217
+ stride=self._stride,
1218
+ padding=None,
1219
+ )
1220
+ )
1221
+ else:
1222
+ layers.append(
1223
+ torch.nn.Conv2d(
1224
+ in_channels=in_channels,
1225
+ out_channels=conv_channels,
1226
+ kernel_size=self._kernel_size,
1227
+ stride=self._stride,
1228
+ padding=self._left_padding,
1229
+ )
1230
+ )
1231
+ layers.append(activation)
1232
+ in_channels = conv_channels
1233
+
1234
+ elif subsampling == "striding_conv1d":
1235
+ in_channels = feat_in
1236
+
1237
+ self._stride = 2
1238
+ self._kernel_size = 5
1239
+ self._ceil_mode = False
1240
+
1241
+ if self.is_causal:
1242
+ self._left_padding = self._kernel_size - 1
1243
+ self._right_padding = self._stride - 1
1244
+ self._max_cache_len = subsampling_factor + 1
1245
+ else:
1246
+ self._left_padding = (self._kernel_size - 1) // 2
1247
+ self._right_padding = (self._kernel_size - 1) // 2
1248
+ self._max_cache_len = 0
1249
+
1250
+ for i in range(self._sampling_num):
1251
+ if self.is_causal:
1252
+ layers.append(
1253
+ CausalConv1D(
1254
+ in_channels=in_channels,
1255
+ out_channels=(
1256
+ feat_out
1257
+ if self._sampling_num == i + 1
1258
+ else conv_channels
1259
+ ),
1260
+ kernel_size=self._kernel_size,
1261
+ stride=self._stride,
1262
+ padding=None,
1263
+ )
1264
+ )
1265
+ else:
1266
+ layers.append(
1267
+ torch.nn.Conv1d(
1268
+ in_channels=in_channels,
1269
+ out_channels=(
1270
+ feat_out
1271
+ if self._sampling_num == i + 1
1272
+ else conv_channels
1273
+ ),
1274
+ kernel_size=self._kernel_size,
1275
+ stride=self._stride,
1276
+ padding=self._left_padding,
1277
+ )
1278
+ )
1279
+ layers.append(activation)
1280
+ in_channels = conv_channels
1281
+
1282
+ elif subsampling == "dw_striding_conv1d":
1283
+ in_channels = feat_in
1284
+
1285
+ self._stride = 2
1286
+ self._kernel_size = 5
1287
+ self._ceil_mode = False
1288
+
1289
+ self._left_padding = (self._kernel_size - 1) // 2
1290
+ self._right_padding = (self._kernel_size - 1) // 2
1291
+
1292
+ # Layer 1
1293
+ layers.extend(
1294
+ [
1295
+ torch.nn.Conv1d(
1296
+ in_channels=in_channels,
1297
+ out_channels=in_channels,
1298
+ kernel_size=self._kernel_size,
1299
+ stride=self._stride,
1300
+ padding=self._left_padding,
1301
+ groups=in_channels,
1302
+ ),
1303
+ torch.nn.Conv1d(
1304
+ in_channels=in_channels,
1305
+ out_channels=(
1306
+ feat_out if self._sampling_num == 1 else conv_channels
1307
+ ),
1308
+ kernel_size=1,
1309
+ stride=1,
1310
+ padding=0,
1311
+ groups=1,
1312
+ ),
1313
+ ]
1314
+ )
1315
+ in_channels = conv_channels
1316
+ layers.append(activation)
1317
+
1318
+ for i in range(self._sampling_num - 1):
1319
+ layers.extend(
1320
+ [
1321
+ torch.nn.Conv1d(
1322
+ in_channels=in_channels,
1323
+ out_channels=in_channels,
1324
+ kernel_size=self._kernel_size,
1325
+ stride=self._stride,
1326
+ padding=self._left_padding,
1327
+ groups=in_channels,
1328
+ ),
1329
+ torch.nn.Conv1d(
1330
+ in_channels=in_channels,
1331
+ out_channels=(
1332
+ feat_out
1333
+ if self._sampling_num == i + 2
1334
+ else conv_channels
1335
+ ),
1336
+ kernel_size=1,
1337
+ stride=1,
1338
+ padding=0,
1339
+ groups=1,
1340
+ ),
1341
+ ]
1342
+ )
1343
+ layers.append(activation)
1344
+ in_channels = conv_channels
1345
+
1346
+ else:
1347
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
1348
+
1349
+ if subsampling in ["dw_striding", "striding"]:
1350
+ in_length = torch.tensor(feat_in, dtype=torch.float)
1351
+ out_length = calc_length(
1352
+ lengths=in_length,
1353
+ all_paddings=self._left_padding + self._right_padding,
1354
+ kernel_size=self._kernel_size,
1355
+ stride=self._stride,
1356
+ ceil_mode=self._ceil_mode,
1357
+ repeat_num=self._sampling_num,
1358
+ )
1359
+ self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
1360
+ self.conv2d_subsampling = True
1361
+ elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
1362
+ self.out = None
1363
+ self.conv2d_subsampling = False
1364
+ else:
1365
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
1366
+
1367
+ self.conv = torch.nn.Sequential(*layers)
1368
+
1369
+ def get_sampling_frames(self):
1370
+ return [1, self.subsampling_factor]
1371
+
1372
+ def get_streaming_cache_size(self):
1373
+ return [0, self.subsampling_factor + 1]
1374
+
1375
+ def forward(self, x, mask):
1376
+ """
1377
+ Forward method for NeMo subsampling.
1378
+
1379
+ Args:
1380
+ x[Batch, Time, Filters]: torch.Tensor
1381
+ input tensor
1382
+ x_mask: torch.Tensor
1383
+ input mask
1384
+
1385
+ Returns:
1386
+ x: torch.Tensor
1387
+ Resulting tensor from subsampling (B, T //
1388
+ time_reduction_factor, feat_out)
1389
+ pad_mask: torch.Tensor
1390
+ tensor of padded hidden state sequences (B, 1, T //
1391
+ time_reduction_factor)
1392
+ """
1393
+ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
1394
+
1395
+ # split inputs if chunking_factor is set
1396
+ if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
1397
+ if self.subsampling_conv_chunking_factor == 1:
1398
+ # if subsampling_conv_chunking_factor is 1, we split only
1399
+ # if needed.
1400
+ # avoiding a bug / feature limiting indexing of tensors
1401
+ # to 2**31.
1402
+ # see https://github.com/pytorch/pytorch/issues/80020
1403
+ x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
1404
+ need_to_split = torch.numel(x) > x_ceil
1405
+ else:
1406
+ # if subsampling_conv_chunking_factor > 1 we always split
1407
+ need_to_split = True
1408
+
1409
+ if need_to_split:
1410
+ x, success = self.conv_split_by_batch(x)
1411
+ if not success: # if unable to split by batch, try by channel
1412
+ if self._subsampling == "dw_striding":
1413
+ x = self.conv_split_by_channel(x)
1414
+ else:
1415
+ x = self.conv(x) # try anyway
1416
+ else:
1417
+ x = self.conv(x)
1418
+ else:
1419
+ x = self.conv(x)
1420
+
1421
+ # Flatten Channel and Frequency Axes
1422
+ if self.conv2d_subsampling:
1423
+ b, c, t, f = x.size()
1424
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
1425
+ # Transpose to Channel Last mode
1426
+ else:
1427
+ x = x.transpose(1, 2)
1428
+
1429
+ if mask is None:
1430
+ return x, None
1431
+
1432
+ max_audio_length = x.shape[1]
1433
+ feature_lens = mask.sum(1)
1434
+ padding_length = torch.ceil(feature_lens / self.subsampling_factor)
1435
+ if self.is_causal and self.subsampling_causal_cond:
1436
+ feature_lens_remainder = feature_lens % self.subsampling_factor
1437
+ padding_length[feature_lens_remainder != 1] += 1
1438
+ pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
1439
+ padding_length.size(0), -1
1440
+ ) < padding_length.unsqueeze(1)
1441
+ return x, pad_mask.unsqueeze(1)
1442
+
1443
+ def reset_parameters(self):
1444
+ # initialize weights
1445
+ if self._subsampling == "dw_striding":
1446
+ with torch.no_grad():
1447
+ # init conv
1448
+ scale = 1.0 / self._kernel_size
1449
+ dw_max = (self._kernel_size**2) ** -0.5
1450
+ pw_max = self._conv_channels**-0.5
1451
+
1452
+ torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
1453
+ torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
1454
+
1455
+ for idx in range(2, len(self.conv), 3):
1456
+ torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
1457
+ torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
1458
+ torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
1459
+ torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)
1460
+
1461
+ # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
1462
+ # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
1463
+ # src/models/conformer_encoder.py#L487
1464
+ fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
1465
+ torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
1466
+ torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
1467
+
1468
+ def conv_split_by_batch(self, x):
1469
+ """Tries to split input by batch, run conv and concat results"""
1470
+ b, _, _, _ = x.size()
1471
+ if b == 1: # can't split if batch size is 1
1472
+ return x, False
1473
+
1474
+ if self.subsampling_conv_chunking_factor > 1:
1475
+ cf = self.subsampling_conv_chunking_factor
1476
+ else:
1477
+ # avoiding a bug / feature limiting indexing of tensors to 2**31
1478
+ # see https://github.com/pytorch/pytorch/issues/80020
1479
+ x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
1480
+ p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
1481
+ cf = 2**p
1482
+
1483
+ new_batch_size = b // cf
1484
+ if new_batch_size == 0: # input is too big
1485
+ return x, False
1486
+
1487
+ return (
1488
+ torch.cat(
1489
+ [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]
1490
+ ),
1491
+ True,
1492
+ )
1493
+
1494
+     def conv_split_by_channel(self, x):
+         """For dw convs, tries to split input by channel (depthwise convs)
+         and by time (pointwise convs), run conv and concat results"""
+         x = self.conv[0](x)  # full conv2D
+         x = self.conv[1](x)  # activation
+
+         for i in range(self._sampling_num - 1):
+             _, c, t, _ = x.size()
+
+             if self.subsampling_conv_chunking_factor > 1:
+                 cf = self.subsampling_conv_chunking_factor
+             else:
+                 # avoiding a bug / feature limiting indexing of tensors
+                 # to 2**31
+                 # see https://github.com/pytorch/pytorch/issues/80020
+                 p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
+                 cf = 2**p
+
+             new_c = int(c // cf)
+             if new_c == 0:
+                 new_c = 1
+
+             new_t = int(t // cf)
+             if new_t == 0:
+                 new_t = 1
+
+             x = self.channel_chunked_conv(
+                 self.conv[i * 3 + 2], new_c, x
+             )  # conv2D, depthwise
+
+             # splitting pointwise convs by time
+             x = torch.cat(
+                 [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)],
+                 2,
+             )  # conv2D, pointwise
+             x = self.conv[i * 3 + 4](x)  # activation
+         return x
+
+     def channel_chunked_conv(self, conv, chunk_size, x):
+         """Performs channel chunked convolution"""
+
+         ind = 0
+         out_chunks = []
+         for chunk in torch.split(x, chunk_size, 1):
+             step = chunk.size()[1]
+
+             if self.is_causal:
+                 chunk = nn.functional.pad(
+                     chunk,
+                     pad=(
+                         self._kernel_size - 1,
+                         self._stride - 1,
+                         self._kernel_size - 1,
+                         self._stride - 1,
+                     ),
+                 )
+                 ch_out = nn.functional.conv2d(
+                     chunk,
+                     conv.weight[ind : ind + step, :, :, :],
+                     bias=conv.bias[ind : ind + step],
+                     stride=self._stride,
+                     padding=0,
+                     groups=step,
+                 )
+             else:
+                 ch_out = nn.functional.conv2d(
+                     chunk,
+                     conv.weight[ind : ind + step, :, :, :],
+                     bias=conv.bias[ind : ind + step],
+                     stride=self._stride,
+                     padding=self._left_padding,
+                     groups=step,
+                 )
+             out_chunks.append(ch_out)
+             ind += step
+
+         return torch.cat(out_chunks, 1)
+
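Because each depthwise filter touches only its own channel, convolving channel chunks with the matching weight slices reproduces the full convolution exactly. A minimal self-contained check of that equivalence (hypothetical shapes, non-causal path):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c, k = 8, 3
conv = nn.Conv2d(c, c, k, stride=1, padding=k // 2, groups=c)  # depthwise
x = torch.randn(1, c, 16, 16)

out_chunks, ind = [], 0
for chunk in torch.split(x, 4, 1):  # chunk_size = 4 channels
    step = chunk.size(1)
    out_chunks.append(
        nn.functional.conv2d(
            chunk,
            conv.weight[ind : ind + step],
            bias=conv.bias[ind : ind + step],
            stride=1,
            padding=k // 2,
            groups=step,
        )
    )
    ind += step

assert torch.allclose(torch.cat(out_chunks, 1), conv(x), atol=1e-6)
```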
+     def change_subsampling_conv_chunking_factor(
+         self, subsampling_conv_chunking_factor: int
+     ):
+         if (
+             subsampling_conv_chunking_factor != -1
+             and subsampling_conv_chunking_factor != 1
+             and subsampling_conv_chunking_factor % 2 != 0
+         ):
+             raise ValueError(
+                 "subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
+             )
+         self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
+
+
+ def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
+     """Calculates the output length of a Tensor passed through a convolution or
+     max pooling layer"""
+     add_pad: float = all_paddings - kernel_size
+     one: float = 1.0
+     for _ in range(repeat_num):
+         lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
+         lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
+     return lengths.to(dtype=torch.int)
+
+
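This is the standard convolution length formula L_out = floor((L_in + 2p - k) / s) + 1 (or ceil with ceil_mode), applied repeat_num times. A quick worked check with hypothetical values:

```python
import torch

lengths = torch.tensor([100, 57])
# kernel 3, stride 2, padding 1 on each side -> all_paddings = 2
out = calc_length(lengths, all_paddings=2, kernel_size=3, stride=2,
                  ceil_mode=False, repeat_num=2)
print(out)  # tensor([25, 15]): 100 -> 50 -> 25 and 57 -> 29 -> 15
```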
+ #### multihead attention starts here
+ class AttModule(nn.Module):
+     """Attention abstraction module"""
+
+     def __init__(self):
+         super().__init__()
+         self.export_mode = False
+
+     def set_export(self, mode=True):
+         """set the export mode"""
+         self.export_mode = mode
+
+     def forward(
+         self,
+         x: Tensor,
+         memory: Optional[Tensor] = None,
+         pos_emb: Optional[Tensor] = None,
+         att_mask: Optional[Tensor] = None,
+     ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+         """AttModule forward
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+             memory: torch.Tensor, optional
+                 memory tensor.
+             pos_emb: torch.Tensor, optional
+                 positional encoder embedding.
+             att_mask: torch.Tensor, optional
+                 attention mask tensor.
+         """
+         return x, memory, pos_emb, att_mask
+
+
+ class AttBlock(BlockBase, AttModule):
+     """Attention Block module to support both Attention and Block module."""
+
+     def memory_dims(self, max_len=False):
+         """memory dimensions"""
+         return (1, self.input_size)
+
+
+ def masked_softmax(
+     scores,
+     mask: Optional[Tensor],
+ ):
+     if mask is not None:
+         mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
+         scores = scores.masked_fill(mask, -torch.inf)
+         attn = torch.softmax(scores, dim=-1).masked_fill(
+             mask, 0.0
+         )  # (batch, head, time1, time2)
+     else:
+         attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+     return attn
+
+
+ class MultiHeadedAttention(nn.Module):
+     """Multi-Head Attention layer with optional relative position embedding
+     and GLU.
+
+     Args:
+         n_head: int
+             the number of heads.
+         n_feat: int
+             input feature size.
+         dropout_rate: float
+             dropout rate.
+         attention_inner_dim: int, optional
+             the attention dimension used in the class,
+             it can be different from the input dimension n_feat.
+             default: -1 (equal to n_feat).
+         use_pt_scaled_dot_product_attention: bool, optional
+             if set True, use pytorch scaled dot product attention in training.
+             NOTE: this will NOT be used in ONNX decoding due to a lack of
+             support. In that case, we use the original attention
+             implementation, which shows no regression.
+             default: False.
+         n_value: int, optional
+             if set to a value other than -1, use a different dimension for
+             value. With the default value (i.e. -1), it is backward compatible.
+         group_size: int, optional (must divide n_head)
+             if group_size > 1: GQA
+             if group_size = 1: MHA
+             if group_size = n_head: MQA
+     """
+
+     inv_sqrt_d_k: torch.jit.Final[float]
+     h: torch.jit.Final[int]
+     h_k: torch.jit.Final[int]
+     g: torch.jit.Final[int]
+
+     def __init__(
+         self,
+         n_head,
+         n_feat,
+         dropout_rate,
+         attention_inner_dim=-1,
+         glu_type="swish",
+         bias_in_glu=True,
+         use_pt_scaled_dot_product_attention=False,
+         n_value=-1,
+         group_size: int = 1,
+     ):
+         super().__init__()
+         if n_value == -1:
+             n_value = n_feat
+         if attention_inner_dim == -1:
+             attention_inner_dim = n_feat
+         assert attention_inner_dim % n_head == 0
+
+         # We assume d_v always equals d_k
+         self.d_k = attention_inner_dim // n_head
+         self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
+         self.h = n_head
+         assert n_head % group_size == 0, "group_size must divide n_head"
+         self.g = group_size
+         self.h_k = n_head // group_size
+
+         self.linear_q = nn.Linear(n_feat, attention_inner_dim)
+         self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
+         self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
+         self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)
+
+         self.attn = torch.jit.Attribute(None, Optional[Tensor])
+         self.dropout = nn.Dropout(p=dropout_rate)
+         self.dropout_rate = dropout_rate
+         self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention
+
+         if use_pt_scaled_dot_product_attention and group_size > 1:
+             raise ValueError("Cannot use PT Scaled Attention with GQA")
+
+         # Torchscript eager quantization. Note that these functions below are
+         # NOOPs and have very little impact on performance unless quantization
+         # is enabled.
+         self.quant_q = torch.ao.quantization.QuantStub()
+         self.quant_x = torch.ao.quantization.QuantStub()
+         self.dequant = torch.ao.quantization.DeQuantStub()
+         self.ffunc = torch.ao.nn.quantized.FloatFunctional()
+
+     def forward(
+         self,
+         query: Tensor,
+         key: Tensor,
+         value: Tensor,
+         pos_k: Tensor,
+         pos_v: Tensor,
+         mask: Optional[Tensor],
+         relative_attention_bias: Optional[Tensor] = None,
+     ):
+         """Compute 'Scaled Dot Product Attention'.
+
+         Args:
+             query: torch.Tensor
+                 query tensor (batch, time1, size)
+             key: torch.Tensor
+                 key tensor (batch, time2, size)
+             value: torch.Tensor
+                 value tensor (batch, time2, size)
+             pos_k: torch.Tensor
+                 key tensor used for relative positional embedding.
+             pos_v: torch.Tensor
+                 value tensor used for relative positional embedding.
+             mask: torch.Tensor
+                 mask tensor (batch, time1, time2)
+             relative_attention_bias: torch.Tensor
+                 bias added to attention logits w.r.t. relative positions
+                 (1, n_head, time1, time2)
+         """
+         n_batch = query.size(0)
+
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)  # (b, t, h, d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k)  # (b, t, h_k, d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
+         q = (
+             q.transpose(1, 2)
+             if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
+             else q.transpose(1, 2) * self.inv_sqrt_d_k
+         )
+         k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)
+
+         if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
+             attn_mask = None
+             if mask is not None:
+                 mask = mask.unsqueeze(1)
+                 if relative_attention_bias is not None:
+                     attn_mask = mask + relative_attention_bias
+                 else:
+                     attn_mask = mask
+                 if mask.dtype != q.dtype:
+                     attn_mask = attn_mask.to(q.dtype)
+
+             with torch.nn.attention.sdpa_kernel(
+                 [
+                     torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                     torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                     torch.nn.attention.SDPBackend.MATH,
+                     torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+                 ]
+             ):
+                 x = torch.nn.functional.scaled_dot_product_attention(
+                     q,
+                     k,
+                     v,
+                     attn_mask=attn_mask,
+                     dropout_p=self.dropout_rate,
+                 )
+         else:
+             if self.h != self.h_k:
+                 q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
+                 A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
+             else:
+                 A = torch.matmul(q, k.transpose(-2, -1))
+             if pos_k is not None:
+                 if self.h != self.h_k:
+                     B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
+                 else:
+                     reshape_q = (
+                         q.contiguous()
+                         .view(n_batch * self.h, -1, self.d_k)
+                         .transpose(0, 1)
+                     )  # (t1, b*h, d_k)
+                     B = torch.matmul(
+                         reshape_q, pos_k.transpose(-2, -1)
+                     )  # pos_k: (t1, d_k, t2)
+                     B = B.transpose(0, 1).view(
+                         n_batch, self.h, pos_k.size(0), pos_k.size(1)
+                     )
+                 scores = A + B
+             else:
+                 scores = A
+
+             if relative_attention_bias is not None:
+                 scores = scores + relative_attention_bias
+
+             attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)
+
+             self.attn = attn
+
+             p_attn = self.dropout(attn)
+             x = torch.matmul(p_attn.to(v.dtype), v)  # (batch, head, time1, d_k)
+             if pos_v is not None:
+                 reshape_attn = (
+                     p_attn.contiguous()
+                     .view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
+                     .transpose(0, 1)
+                 )  # (t1, b*h, t2)
+
+                 attn_v = (
+                     torch.matmul(reshape_attn, pos_v)
+                     .transpose(0, 1)
+                     .contiguous()
+                     .view(n_batch, self.h, pos_v.size(0), self.d_k)
+                 )
+                 x = x + attn_v
+         x = (
+             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
+         )  # (batch, time1, d_model)
+
+         return self.linear_out(x)  # (batch, time1, d_model)
+
+
+ class MultiSequential(torch.nn.Sequential):
+     """Multi-input multi-output torch.nn.Sequential"""
+
+     @torch.jit.ignore
+     def forward(self, *args):
+         """Forward method implementation."""
+         for m in self:
+             args = m(*args)
+         return args
+
+
+ def get_offset(input_layer: str, time_reduction: int):
+     """Get an offset. The offset is used to determine the number of frames
+     of a subsampled feature.
+
+     Args:
+         input_layer (str): type of the input layer
+         time_reduction (int): time reduction factor for downsampling a feature
+     Returns:
+         int: offset
+     """
+     if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
+         return 3
+     if input_layer in ("conv2d",) and time_reduction == 6:
+         return 1
+     if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
+         return 7
+     return 0
+
+
+ def unfold_tensor(xs_pad, max_seq_len):
+     """
+     For a given tensor with shape of (N, T, D), if sequence length T is
+     longer than max_seq_len, this function unfolds it to a
+     (NT', max_seq_len, D) tensor, where T' is T // max_seq_len.
+     Args:
+         xs_pad: N, T, D
+         max_seq_len: maximum sequence length per unfolded chunk
+     """
+     _, _, D = xs_pad.shape
+     xs_pad = xs_pad.transpose(-1, -2)  # convert to N, D, T
+     # N x D x 1 x T => N x (D x max_seq_len) x T'
+     xs_pad = F.unfold(
+         xs_pad[..., None, :],
+         kernel_size=(1, max_seq_len),
+         stride=(1, max_seq_len),
+     )
+     new_bsz, _, slen = xs_pad.shape
+     # N x D x max_seq_len x T'
+     xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
+     # N x T' x max_seq_len x D
+     xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
+     # NT' x max_seq_len x D
+     xs_pad = xs_pad.view(-1, max_seq_len, D)
+     return xs_pad
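A shape-level check of the unfolding (hypothetical sizes; T should be a multiple of max_seq_len so the chunks tile the sequence exactly):

```python
import torch

xs_pad = torch.randn(2, 12, 5)  # (N, T, D)
out = unfold_tensor(xs_pad, max_seq_len=4)
print(out.shape)  # torch.Size([6, 4, 5]) = (N*T', max_seq_len, D)

# Each chunk is a contiguous slice of the original sequence.
assert torch.allclose(out[0], xs_pad[0, :4])
```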