sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1917 @@
|
|
1
|
+
# Copyright 2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
#!/usr/bin/env python3
|
15
|
+
import math
|
16
|
+
from typing import Optional, Union
|
17
|
+
|
18
|
+
import torch
|
19
|
+
import torch.nn.functional as F
|
20
|
+
from torch import Tensor, nn
|
21
|
+
|
22
|
+
|
23
|
+
class BlockBase(nn.Module):
    """Abstract base for encoder blocks; records the block's feature sizes.

    Args:
        input_size: size of the features fed into the block.
        output_size: size of the features produced by the block.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Store both sizes so subclasses and callers can query them later.
        self.input_size, self.output_size = input_size, output_size
|
30
|
+
|
31
|
+
|
32
|
+
def get_activation(name="relu"):
    """Construct an activation module from its case-insensitive name.

    Args:
        name: str
            activation function name,
            one of ["relu", "gelu", "swish", "sigmoid"],
            default "relu". Any other name yields ``nn.Identity()``.

    Returns:
        nn.Module: a newly constructed activation module
        (``nn.ReLU`` is built with ``inplace=True``).
    """
    # Each entry is a zero-argument factory, so only the selected module is built.
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        "swish": lambda: Swish(),
        "sigmoid": torch.nn.Sigmoid,
    }
    factory = factories.get(name.lower(), nn.Identity)
    return factory()
|
51
|
+
|
52
|
+
|
53
|
+
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """Build the chunk-based attention mask used for streaming encoders.

    Every frame belongs to the chunk whose ``[start, end)`` range contains it.
    A frame may attend to all frames from the start of the chunk
    ``left_window`` chunks before its own up to (excluding) the end of the
    chunk ``right_window`` chunks after its own.

    Args:
        x_len (int): sequence length.
        chunk_start_idx (list): first idx of each chunk, such as [0, 18, 36, 48].
            Adaptive (unequal) chunk sizes are supported, e.g. [0, 10, 15, 45].
        left_window (int): how many left chunks can be seen.
        right_window (int): how many right chunks can be seen. It is used for
            chunk overlap models.

    Returns:
        torch.Tensor: boolean mask of shape [x_len, x_len]; ``mask[t, s]`` is
        True when frame ``t`` may attend to frame ``s``. For example,
        ``adaptive_enc_mask(4, [0, 2])`` returns::

            tensor([[ True,  True, False, False],
                    [ True,  True, False, False],
                    [False, False,  True,  True],
                    [False, False,  True,  True]])
    """
    # Build the indices directly as int64: the previous float32 round-trip
    # (torch.Tensor(...).long()) is not exact for values above 2**24.
    chunk_start_idx = torch.tensor(chunk_start_idx, dtype=torch.long)
    num_chunks = len(chunk_start_idx)
    # Prepend 0 -> [0, 0, 18, 36, 48]: start of each (padded-index) chunk.
    start_pad = F.pad(chunk_start_idx, (1, 0))
    # Append x_len -> [0, 18, 36, 48, x_len]: end of each (padded-index) chunk.
    end_pad = F.pad(chunk_start_idx, (0, 1), value=x_len)
    seq_range = torch.arange(0, x_len).unsqueeze(-1)  # [x_len, 1]
    # Padded index of the chunk containing each frame: [x_len]
    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
    # Column (key) positions for every query row: [x_len, x_len]
    seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)

    # Left boundary: start of the chunk `left_window` chunks back (clamped).
    idx_left = (idx - left_window).clamp(min=0)
    boundary_left = start_pad[idx_left]
    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)

    # Right boundary: end of the chunk `right_window` chunks ahead (clamped to
    # the last valid index of end_pad).
    idx_right = (idx + right_window).clamp(max=num_chunks)
    boundary_right = end_pad[idx_right]
    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
    return mask_left & mask_right
|
100
|
+
|
101
|
+
|
102
|
+
class Swish(nn.Module):
    """Swish activation module computing ``x * sigmoid(x)``.

    From https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` element-wise by its own sigmoid.

        Args:
            x: torch.Tensor
                Input.
        """
        gate = self.act_fn(x)
        return x * gate
|
120
|
+
|
121
|
+
|
122
|
+
class GLU(nn.Module):
    """Gated Linear Unit (GLU) module.

    The input is split into two equal halves along ``dim``; the result is
    ``first_half * act(second_half)``.

    Args:
        dim: dimension along which the input is split, default -1.
        act_name: gate activation, one of ["relu", "gelu", "swish",
            "sigmoid"] (case-insensitive); any other name means no
            activation (identity). Default "sigmoid".
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        # Zero-argument factories: only the selected activation is constructed.
        builders = {
            "relu": lambda: nn.ReLU(inplace=True),
            "gelu": lambda: nn.GELU(),
            "swish": lambda: Swish(),
            "sigmoid": lambda: nn.Sigmoid(),
        }
        self.act_fn = builders.get(self.act_name, nn.Identity)()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Multiply the first half of ``x`` (along ``self.dim``) by the
        activation of the second half.

        Args:
            x: torch.Tensor
                Input; its size along ``self.dim`` is split into two chunks.
        """
        value, gate = x.chunk(2, dim=self.dim)
        return value * self.act_fn(gate)
|
153
|
+
|
154
|
+
|
155
|
+
# TODO: Abdel, this can be improved using GLU module
|
156
|
+
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for conformer architecture,
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    A 1-D convolution maps ``input_dim`` channels to ``2 * output_dim``
    channels, which are split into a value half and a gate half. The output is
    ``(value + b1) * act(gate + b2)``; for ``glu_type="bilinear"`` it is
    ``(value + b1) * (gate + b2)``. ``b1``/``b2`` exist only when
    ``bias_in_glu`` is True. Input and output are channels-last.

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            gate activation, one of
            ["sigmoid", "relu", "gelu", "swish", "bilinear"],
            default "sigmoid".
        bias_in_glu: bool, optional
            use additive bias in glu
        causal: bool, optional
            if set to True, the convolution pads ``kernel_size - 1`` frames on
            both sides, so the output is ``kernel_size - 1`` frames longer than
            the input; dropping those trailing frames (as ConvModule does)
            makes the convolution unable to see future frames. Otherwise the
            padding is ``(kernel_size - 1) // 2``.
            default False.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        glu_type="sigmoid",
        bias_in_glu=True,
        causal=False,
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = (kernel_size - 1) if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim,
            output_dim * 2,
            kernel_size,
            1,
            padding=padding,
        )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "bilinear":
            # forward() multiplies the two halves directly; previously this
            # type was rejected here, leaving that forward branch unreachable.
            self.glu_act = nn.Identity()
        else:
            # Report the argument: `self.glu_act` is not assigned yet at this
            # point, so referencing it raised AttributeError, not ValueError.
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor, channels in the last dimension.
        """
        # to be consistent with GLULinear, we assume the input always has the
        # #channel (#dim) in the last dimension of the tensor, so need to
        # switch the dimension first for 1D-Conv case
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)
        value = x[:, 0 : self.output_dim, :]
        gate = x[:, self.output_dim : self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2
        if self.glu_type == "bilinear":
            x = value * gate
        else:
            x = value * self.glu_act(gate)

        x = x.permute([0, 2, 1])
        return x
|
260
|
+
|
261
|
+
|
262
|
+
class DepthWiseSeperableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution used in the conformer Convnet module.

    For more details see: https://arxiv.org/pdf/2005.08100v1.pdf

    A grouped (depthwise) convolution with ``groups=input_dim`` produces
    ``input_dim * depthwise_multiplier`` channels; an optional 1x1 pointwise
    convolution then maps them to ``depthwise_seperable_out_channel``.

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            output channels of the pointwise (second) conv1d layer. If it is
            0, the pointwise layer is skipped.
        kernel_size: int
            kernel size of the depthwise conv.
        depthwise_multiplier: int
            number of times each input channel is duplicated; determines the
            hidden channel count ``input_dim * depthwise_multiplier``.
        padding: int, optional
            padding for the depthwise conv1d,
            default: 0.
    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()

        hidden_channels = input_dim * depthwise_multiplier
        self.dw_conv = nn.Conv1d(
            input_dim,
            hidden_channels,
            kernel_size,
            1,
            padding=padding,
            groups=input_dim,
        )

        self.pw_conv = (
            nn.Conv1d(hidden_channels, depthwise_seperable_out_channel, 1, 1, 0)
            if depthwise_seperable_out_channel != 0
            else nn.Identity()
        )
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv if configured.

        Args:
            x: torch.Tensor
                input tensor, channels in dimension 1 (as nn.Conv1d expects).
        """
        x = self.dw_conv(x)
        if self.depthwise_seperable_out_channel == 0:
            return x
        return self.pw_conv(x)
|
328
|
+
|
329
|
+
|
330
|
+
class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Pipeline (see forward): LayerNorm -> GLU (or, when
    ``ext_pw_out_channel == 0``, a learned scalar affine map) ->
    depthwise-separable conv1d -> optional batch norm -> activation ->
    optional extra pointwise conv1d -> dropout. Input and output are
    channels-last.

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if != 0, the channel size of the GLU output and of the last
            pointwise conv applied after the activation (projected back to
            input_dim by ln1 when it differs). If 0, both are replaced by
            learned scalar affine maps.
        depthwise_seperable_out_channel: int
            if set different to 0, it is used as the channel_out of the
            second (pointwise) conv1d layer of the depthwise-separable conv.
            If it is 0, that second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size of the depthwise conv.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
            will be used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolution have no access
            to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            Accepted for config compatibility but not used by this class.
            (Documented meaning: 0 for offline SE; 1 for streaming SE with
            the mean accumulated over history; 2 for streaming SE with the
            mean over the current chunk only.)
        chunk_size: int, optional
            chunk size for cnn. default 18. Accepted but not used by this class.
        activation: str, optional
            activation function used in ConvModule (see get_activation),
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
            otherwise, use GLUPointWiseConv module.
            default to False.
        export: bool, optional
            if set to True (and causal), the depthwise conv padding is 0.
            This is for inference or onnx export. Typically this is set by the
            export program or the decoder program, and it isn't present in
            your config file.
            default False
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        # Must run after the attributes above are set (it reads them). It also
        # assigns a placeholder self.bn_layer that the batch_norm branch below
        # may replace.
        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        # Causal: pad kernel_size - 1 on both sides; forward() drops the
        # trailing kernel_size - 1 frames. With export, no padding is added here.
        if causal:
            padding = 0 if export else kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # ln2 (only created when needed) projects the depthwise-separable
        # output back to input_dim channels; forward() checks hasattr(self, "ln2").
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of __init__ function
        and dedicated to the convolution module creation
        of the conformer.

        Creates the GLU (self.glu), the extra pointwise conv
        (self.ext_pw_conv_1d), the optional ln1 projection and the flags
        apply_ln1 / fix_len1. When ext_pw_out_channel == 0 it instead creates
        the scalar parameters pw_conv_simplify_w / pw_conv_simplify_b
        (3 entries each, used by forward()).
        """
        # Give every optional submodule a default so the attribute always
        # exists (the original comments call this a jit hack).
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
            nn.Identity()
        )  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                # Causal padding lengthens the output by kernel - 1 frames;
                # forward() trims them when fix_len1 is set.
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.glu_type,
                    self.bias_in_glu,
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            # ln1 maps ext_pw_out_channel back to input_dim; used after both
            # the GLU and the extra pointwise conv in forward().
            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            # Scalar weights/biases: indices 0 and 1 replace the GLU, index 2
            # replaces the extra pointwise conv (see forward()).
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                input tensor. Channels are expected in the last dimension
                (LayerNorm(input_dim) is applied to it before the permute to
                the Conv1d layout); presumably (batch, time, input_dim) —
                confirm against callers.

        Returns:
            torch.Tensor: output with channels in the last dimension.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # Drop the extra trailing frames produced by causal padding.
            if self.causal and self.ext_pw_kernel_size > 1:
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            # Learned scalar affine maps instead of a GLU; their sum equals
            # x * (w0 + w1) + (b0 + b1).
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # Channels-last -> channels-first for the Conv1d layers.
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        # Drop the trailing frames added by causal padding of the depthwise conv.
        if self.causal and self.kernel_size > 1:
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            # Back to channels-last.
            x = x.permute([0, 2, 1])
        else:
            # Channels-first -> channels-last (via a singleton dim), then the
            # third scalar affine map replaces the extra pointwise conv.
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x
|
563
|
+
|
564
|
+
|
565
|
+
class GLULinear(nn.Module):
    """Linear + GLU module

    Projects the input to ``2 * output_dim`` features with a linear layer, then
    applies ``GLU`` over the last dimension, producing ``output_dim`` features.

    Args:
        input_dim: int
            input size
        output_dim: int
            output size.
        glu_type: str, optional
            name of the activation applied to the gate half inside the GLU
            (see ``GLU``), default "sigmoid".
        bias_in_glu: bool, optional
            If True, the linear projection has an additive bias. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        # The linear layer's third positional argument is its `bias` flag.
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x):
        """GLULinear forward

        Args:
            x: torch.Tensor
                input tensor with input_dim features in the last dimension.
        """
        x = self.linear(x)
        return self.glu_act(x)
|
600
|
+
|
601
|
+
|
602
|
+
class FeedForward(nn.Module):
    """FeedForward Module.
    For more details see Conformer paper:
    https://arxiv.org/pdf/2005.08100.pdf

    Computes ``Dropout(Linear(Dropout(GLULinear(LayerNorm(x)))))``, mapping
    ``d_model -> d_inner -> d_model``. No residual connection is added here.

    Args:
        d_model: int
            input and output size.
        d_inner: int
            hidden size (output size of the GLULinear stage).
        dropout_rate: float,
            dropout rate.
        activation: str,
            name of the gate activation passed to GLULinear as its glu_type
            (see ``GLU``: "relu", "gelu", "swish", "sigmoid"; other names mean
            identity), default "sigmoid".
        bias_in_glu: bool, optional
            if True, the GLULinear projection has an additive bias.
            Default True.
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """FeedForward forward function.

        Args:
            x: torch.Tensor
                input tensor with d_model features in the last dimension.

        Returns:
            torch.Tensor: tensor with d_model features in the last dimension.
        """
        out = self.net(self.layer_norm(x))

        return out
|
653
|
+
|
654
|
+
|
655
|
+
#### positional encoding starts here
|
656
|
+
def _pre_hook(
|
657
|
+
state_dict,
|
658
|
+
prefix,
|
659
|
+
local_metadata,
|
660
|
+
strict,
|
661
|
+
missing_keys,
|
662
|
+
unexpected_keys,
|
663
|
+
error_msgs,
|
664
|
+
):
|
665
|
+
"""Perform pre-hook in load_state_dict for backward compatibility.
|
666
|
+
|
667
|
+
Note:
|
668
|
+
We saved self.pe until v.0.5.2 but we have omitted it later.
|
669
|
+
Therefore, we remove the item "pe" from `state_dict` for backward
|
670
|
+
compatibility.
|
671
|
+
|
672
|
+
"""
|
673
|
+
k = prefix + "pe"
|
674
|
+
if k in state_dict:
|
675
|
+
state_dict.pop(k)
|
676
|
+
|
677
|
+
|
678
|
+
class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section
    2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/
    transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based
    on relative position of the query and key. It is HxNxN, where H is the
    number of heads, N is the sequence length.

    I've made these modifications to the original T5 bias:
    - Skipping of the bucketing step. Original T5 bias converted rel
      position distances into logarithmically increasing buckets. This is
      supposed to help with length generalization.
    - I just directly use rel position index as bias values, as we don't
      need length generalization (40s max is good enough for ASR encoder),
      and it keeps ONNX export simple.
    - I've also extended it so that biases can be asymmetric, the default
      implementation treats L->R and R->L the same. Asymmetric was found to
      yield better results in my experiments.

    Args:
        num_heads: int
            Number of attention heads
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the
            size of the learnable bias parameter. Bucketing is not yet
            supported, so this defaults to -1 which means no bucketing is
            used (max_distance determines size of bias param). Any value
            >= 0 raises NotImplementedError.
        max_distance: int
            Maximum distance to use for relative attention bias. With
            num_buckets=-1, this directly controls the max size of the bias
            parameter. Relative positions are clipped to
            [-max_distance, max_distance - 1].
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses
            2x number of bias params to distinguish L->R from R->L. This was
            found to be better for the encoder.
    """

    def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError(
                "T5 attention bias with bucketed positions is not yet tested"
            )
        if not self.symmetric:
            # Separate rows for negative and non-negative offsets.
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x):
        """Build the relative-position bias for a sequence shaped like ``x``.

        Args:
            x: torch.Tensor
                Only ``x.size(1)`` (sequence length L) and ``x.device`` are used.

        Returns:
            torch.Tensor of shape [1, H, L, L] (H = num_heads).
        """
        maxpos = x.size(1)
        context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
            :, None
        ]
        memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[
            None, :
        ]
        relative_position = memory_position - context_position
        # clipping to a maximum distance using ops that play well with ONNX
        # export
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
            # |-max_distance| == max_distance is one past the last row of the
            # embedding table (num_buckets == max_distance here); clamp it to
            # the last row instead of raising an IndexError.
            bias_idx = bias_idx.masked_fill(
                bias_idx > self.num_buckets - 1, self.num_buckets - 1
            )
        else:
            # Shift [-max_distance, max_distance - 1] into [0, num_buckets).
            bias_idx += self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position):
        # this is a placeholder (isn't tested, likely buggy) using HuggingFace
        # implem as a reference this also needs to be extended to support
        # asymmetric +/- ve positions. It is currently unreachable because
        # __init__ rejects num_buckets >= 0.
        # `causal` is never set by __init__, so default to the bidirectional
        # branch rather than raising AttributeError, and use a local copy of
        # num_buckets so calling this does not permanently halve the attribute.
        num_buckets = self.num_buckets
        relative_buckets = 0
        if not getattr(self, "causal", False):
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets
|
810
|
+
|
811
|
+
|
812
|
+
class AbsolutePositionalEncoding(nn.Module):
    """Absolute sinusoidal positional encoding module.

    This module implements the absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf
    The input is scaled by sqrt(d_model), the encoding is added, and dropout
    is applied.

    Args:
        d_model: int
            Input embedding size (odd sizes are supported).
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Length of the encoding table built at construction; the table is
            rebuilt on demand for longer inputs. Default 5000.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an AbsolutePositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # `pe` is a plain attribute, not a registered buffer, so it is not
        # part of the state dict; _pre_hook drops it from old checkpoints.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings if the cached table is too short.

        Args:
            x: torch.Tensor
                Reference tensor: only ``x.size(1)`` (required length),
                ``x.dtype`` and ``x.device`` are used.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            # Long enough already; only move/cast if needed.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        length = x.size(1)
        pe = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than div_term entries
        # when d_model is odd; slicing keeps the shapes aligned (and is a
        # no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x: torch.Tensor
                Input tensor. shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
|
872
|
+
|
873
|
+
|
874
|
+
#### forward embedding layers starts here
|
875
|
+
class MeanVarianceNormLayer(nn.Module):
    """Normalize features with learnable global mean / inverse-std vectors.

    Output is ``(input_ - global_mean) * global_invstd``. Both vectors are
    ``nn.Parameter``s of length ``input_size`` initialised to zeros and ones
    respectively, so the layer is an identity until its parameters are set.
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            layer input size.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        mean_init = torch.zeros(input_size)
        invstd_init = torch.ones(input_size)
        self.global_mean = nn.Parameter(mean_init)
        self.global_invstd = nn.Parameter(invstd_init)

    def forward(self, input_: Tensor) -> Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor; its last dimension must broadcast against
                ``input_size``.
        """
        centered = input_ - self.global_mean
        return centered * self.global_invstd
|
900
|
+
|
901
|
+
|
902
|
+
class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step would have limited access to
    locations on its right or left.
    All arguments are the same as nn.Conv1d except padding.

    If padding is set None, then paddings are set automatically to make it a
    causal convolution where each location would not see any steps on its right
    (left padding kernel_size - 1, right padding stride - 1).

    If padding is an int, both sides are padded by that amount.

    If padding is set as a list (size of 2), then padding[0] would be used as
    left padding and padding[1] as right padding.
    It would make it possible to control the number of steps to be accessible
    on the right and left.
    This mode is not supported when stride > 1. padding[0]+padding[1] should
    be equal to (kernel_size - 1).

    The parent nn.Conv1d is always built with padding=0; padding is applied
    in ``update_cache``. When ``forward`` receives a ``cache`` tensor, it is
    prepended instead of the left zero padding and ``(output, next_cache)``
    is returned. ``cache_drop_size`` (None by default, treated as 0) drops that
    many trailing frames before the next cache is taken.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int, list, None] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (
                isinstance(padding, list)
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        """Pad ``x`` (or prepend ``cache``) and compute the next cache.

        Args:
            x: torch.Tensor
                input whose last dimension is time.
            cache: torch.Tensor, optional
                previous frames to prepend instead of left zero padding.

        Returns:
            (padded input, next cache); the next cache is None when ``cache``
            is None, otherwise it has the same last-dimension size as ``cache``.
        """
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            # cache_drop_size defaults to None; treat that as "drop nothing"
            # instead of failing on `None > 0`.
            drop = self.cache_drop_size or 0
            if drop > 0:
                next_cache = new_x[:, :, :-drop]
            else:
                next_cache = new_x
            next_cache = next_cache[:, :, -cache.size(-1) :]
        return new_x, next_cache

    def forward(self, x, cache=None):
        """Run the causal convolution.

        Returns the convolution output, or ``(output, next_cache)`` when a
        ``cache`` is supplied.
        """
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache
|
990
|
+
|
991
|
+
|
992
|
+
class CausalConv2D(nn.Conv2d):
    """
    A causal variant of nn.Conv2d.

    The input's last dimension is padded by ``kernel_size - 1`` on the left
    and ``stride - 1`` on the right; the second-to-last dimension is not
    padded. All arguments are the same as nn.Conv2d except padding, which
    must be None (now the default) because padding is computed internally.

    Raises:
        ValueError: if ``padding`` is not None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int, None] = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # The previous default (0) was always rejected below, so the class
        # could not be built without explicitly passing padding=None.
        if padding is not None:
            raise ValueError("Argument padding should be set to None for CausalConv2D.")
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        padding = 0
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(
        self,
        x,
    ):
        """Pad the last dimension causally, then apply the convolution."""
        x = F.pad(
            x,
            pad=(self._left_padding, self._right_padding, 0, 0),
        )
        x = super().forward(x)
        return x
|
1044
|
+
|
1045
|
+
|
1046
|
+
class NemoConvSubsampling(torch.nn.Module):
    """Convolutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
    34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

    Striding Subsampling: "Speech-Transformer: A No-Recurrence
    Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
    et al. (https://ieeexplore.ieee.org/document/8462506)


    Compared with the EncoderConv2D (`input_layer: custom`), this is a
    much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
    Moreover, depthwise convolutions are used to reduce FLOPs, but the first
    layer is kept as a regular convolution so as not to degrade accuracy.

    `Striding` and `dw_striding` are the same except that the latter uses
    depthwise convolutions after the first layer, whereas the former does not.

    Args:
        subsampling_factor (int): Time reduction factor. Must be a power of 2
            (>= 2): each stride-2 stage halves the time axis.
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling (str): The subsampling technique, choose from
            {"striding", "dw_striding", "striding_conv1d",
            "dw_striding_conv1d"}
        conv_channels (int): Number of channels for the convolution layers,
            default is 256.
        subsampling_conv_chunking_factor (int): Input chunking factor which
            can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1.
            (The check only rejects odd values other than -1 and 1.)
        activation (Module): activation function, default is nn.ReLU()
        is_causal (bool): whether to use causal Conv1/2D, where each step will
            have limited access to locations on its right or left. Not
            honoured by "dw_striding_conv1d".

    Raises:
        ValueError: for an invalid subsampling factor, chunking factor or
            subsampling technique.
    """

    def __init__(
        self,
        feat_in,
        feat_out,
        subsampling_factor=4,
        subsampling="dw_striding",
        conv_channels=256,
        subsampling_conv_chunking_factor=1,
        activation=nn.ReLU(),  # noqa: B008
        is_causal=False,
    ):
        super().__init__()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        # Each stage halves time, so only powers of two are realisable. The
        # former `% 2` check accepted e.g. 6, which built int(log2(6)) == 2
        # stages (a factor of 4) while forward() divided lengths by 6.
        if (
            subsampling_factor < 2
            or 2 ** round(math.log2(subsampling_factor)) != subsampling_factor
        ):
            raise ValueError("Sampling factor should be a power of 2!")
        # math.log2 is exact for powers of two (math.log(x, 2) may not be).
        self._sampling_num = round(math.log2(subsampling_factor))
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        self.subsampling_causal_cond = subsampling in (
            "dw_striding",
            "striding",
            "striding_conv1d",
        )

        self._check_conv_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

        in_channels = 1
        layers = []

        if subsampling == "dw_striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False
            self._init_padding(subsampling_factor)

            # Layer 1: a regular (non-depthwise) conv from 1 input channel.
            if self.is_causal:
                layers.append(
                    CausalConv2D(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=None,
                    )
                )
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                    )
                )
            in_channels = conv_channels
            layers.append(activation)

            # Remaining stages: depthwise conv, pointwise conv, activation.
            for _ in range(self._sampling_num - 1):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                            groups=in_channels,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        )
                    )

                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    )
                )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False
            self._init_padding(subsampling_factor)

            for _ in range(self._sampling_num):
                if self.is_causal:
                    layers.append(
                        CausalConv2D(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=conv_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        )
                    )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False
            self._init_padding(subsampling_factor)

            for i in range(self._sampling_num):
                # The last stage projects straight to feat_out.
                out_channels = (
                    feat_out if self._sampling_num == i + 1 else conv_channels
                )
                if self.is_causal:
                    layers.append(
                        CausalConv1D(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        )
                    )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            # NOTE: is_causal is not honoured here; padding is symmetric and
            # _max_cache_len is not set.
            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2

            # Layer 1
            layers.extend(
                [
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        groups=in_channels,
                    ),
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=(
                            feat_out if self._sampling_num == 1 else conv_channels
                        ),
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ),
                ]
            )
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                layers.extend(
                    [
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=(
                                feat_out
                                if self._sampling_num == i + 2
                                else conv_channels
                            ),
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                    ]
                )
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ["dw_striding", "striding"]:
            # Conv2d modes flatten (channels x frequency) into a linear layer;
            # its input width depends on the subsampled frequency size.
            in_length = torch.tensor(feat_in, dtype=torch.float)
            out_length = calc_length(
                lengths=in_length,
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
            self.conv2d_subsampling = True
        else:  # "striding_conv1d" / "dw_striding_conv1d"
            self.out = None
            self.conv2d_subsampling = False

        self.conv = torch.nn.Sequential(*layers)

    @staticmethod
    def _check_conv_chunking_factor(subsampling_conv_chunking_factor):
        """Raise ValueError unless the chunking factor is -1, 1 or even."""
        if (
            subsampling_conv_chunking_factor != -1
            and subsampling_conv_chunking_factor != 1
            and subsampling_conv_chunking_factor % 2 != 0
        ):
            raise ValueError(
                "subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
            )

    def _init_padding(self, subsampling_factor):
        """Set per-stage paddings and streaming cache length from is_causal.

        Requires ``self._kernel_size`` and ``self._stride`` to be set.
        """
        if self.is_causal:
            self._left_padding = self._kernel_size - 1
            self._right_padding = self._stride - 1
            self._max_cache_len = subsampling_factor + 1
        else:
            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2
            self._max_cache_len = 0

    def get_sampling_frames(self):
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self):
        return [0, self.subsampling_factor + 1]

    def forward(self, x, mask):
        """
        Forward method for NeMo subsampling.

        Args:
            x: torch.Tensor
                input tensor of shape [Batch, Time, Filters]
            mask: torch.Tensor or None
                input mask; ``mask.sum(1)`` is used as each sequence's valid
                length (presumably a (B, T) 0/1 or bool mask).

        Returns:
            x: torch.Tensor
                Resulting tensor from subsampling (B, T //
                time_reduction_factor, feat_out)
            pad_mask: torch.Tensor or None
                tensor of padded hidden state sequences (B, 1, T //
                time_reduction_factor); None when ``mask`` is None.
        """
        x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)

        # split inputs if chunking_factor is set
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
            if self.subsampling_conv_chunking_factor == 1:
                # if subsampling_conv_chunking_factor is 1, we split only
                # if needed.
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31.
                # see https://github.com/pytorch/pytorch/issues/80020
                x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
                need_to_split = torch.numel(x) > x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        # Flatten Channel and Frequency Axes
        if self.conv2d_subsampling:
            b, _, t, _ = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        # Transpose to Channel Last mode
        else:
            x = x.transpose(1, 2)

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
        return x, pad_mask.unsqueeze(1)

    def reset_parameters(self):
        """Re-initialise weights uniformly (only for ``dw_striding``)."""
        # initialize weights
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2) ** -0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                # After conv[0] and its activation, the layout is repeated
                # [depthwise, pointwise, activation] triples starting at 2.
                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
                # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
                # src/models/conformer_encoder.py#L487
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x):
        """Tries to split input by batch, run conv and concat results.

        Returns:
            (output, True) on success, or (x unchanged, False) when the batch
            cannot be split (batch size 1, or the chunk size would be 0).
        """
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
            cf = 2**p

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return (
            torch.cat(
                [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]
            ),
            True,
        )

    def conv_split_by_channel(self, x):
        """For dw_striding: run the first conv whole, then split each
        depthwise conv by channel and each pointwise conv by time, and
        concatenate the results."""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors
                # to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2**p

            new_c = int(c // cf)
            if new_c == 0:
                new_c = 1

            new_t = int(t // cf)
            if new_t == 0:
                new_t = 1

            x = self.channel_chunked_conv(
                self.conv[i * 3 + 2], new_c, x
            )  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat(
                [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)],
                2,
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(self, conv, chunk_size, x):
        """Performs channel chunked (depthwise) convolution.

        Applies ``conv``'s weights to ``chunk_size``-channel slices of ``x``
        and concatenates the outputs along the channel dimension.
        """

        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]
            weight = conv.weight[ind : ind + step, :, :, :]
            bias = conv.bias[ind : ind + step]

            if self.is_causal:
                # Pad exactly like CausalConv2D.forward (last dimension only:
                # kernel_size - 1 left, stride - 1 right) so the chunked path
                # computes the same result as calling the layer directly.
                chunk = nn.functional.pad(
                    chunk,
                    pad=(
                        self._kernel_size - 1,
                        self._stride - 1,
                        0,
                        0,
                    ),
                )
                ch_out = nn.functional.conv2d(
                    chunk,
                    weight,
                    bias=bias,
                    stride=self._stride,
                    padding=0,
                    groups=step,
                )
            else:
                ch_out = nn.functional.conv2d(
                    chunk,
                    weight,
                    bias=bias,
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=step,
                )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(
        self, subsampling_conv_chunking_factor: int
    ):
        """Validate and set a new conv chunking factor."""
        self._check_conv_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
|
1584
|
+
|
1585
|
+
|
1586
|
+
def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
|
1587
|
+
"""Calculates the output length of a Tensor passed through a convolution or
|
1588
|
+
max pooling layer"""
|
1589
|
+
add_pad: float = all_paddings - kernel_size
|
1590
|
+
one: float = 1.0
|
1591
|
+
for i in range(repeat_num):
|
1592
|
+
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
|
1593
|
+
lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
|
1594
|
+
return lengths.to(dtype=torch.int)
|
1595
|
+
|
1596
|
+
|
1597
|
+
#### multihead attention starts here
|
1598
|
+
class AttModule(nn.Module):
|
1599
|
+
"""Attention abstraction module"""
|
1600
|
+
|
1601
|
+
def __init__(self):
|
1602
|
+
super().__init__()
|
1603
|
+
self.export_mode = False
|
1604
|
+
|
1605
|
+
def set_export(self, mode=True):
|
1606
|
+
"""set the export mode"""
|
1607
|
+
self.export_mode = mode
|
1608
|
+
|
1609
|
+
def forward(
|
1610
|
+
self,
|
1611
|
+
x: Tensor,
|
1612
|
+
memory: Optional[Tensor] = None,
|
1613
|
+
pos_emb: Optional[Tensor] = None,
|
1614
|
+
att_mask: Optional[Tensor] = None,
|
1615
|
+
) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
1616
|
+
"""AttModule forward
|
1617
|
+
|
1618
|
+
Args:
|
1619
|
+
x: torch.Tensor
|
1620
|
+
input tensor.
|
1621
|
+
memory: torch.Tensor, optional
|
1622
|
+
memory tensor.
|
1623
|
+
pos_emb: torch.Tensor, optional
|
1624
|
+
positional encoder embedding.
|
1625
|
+
att_mask: torch.Tensor, optional
|
1626
|
+
attention mask tensor.
|
1627
|
+
"""
|
1628
|
+
return x, memory, pos_emb, att_mask
|
1629
|
+
|
1630
|
+
|
1631
|
+
class AttBlock(BlockBase, AttModule):
|
1632
|
+
"""Attention Block module to support both Attention and Block module."""
|
1633
|
+
|
1634
|
+
def memory_dims(self, max_len=False):
|
1635
|
+
"""memory dimensions"""
|
1636
|
+
return (1, self.input_size)
|
1637
|
+
|
1638
|
+
|
1639
|
+
def masked_softmax(
|
1640
|
+
scores,
|
1641
|
+
mask: Optional[Tensor],
|
1642
|
+
):
|
1643
|
+
if mask is not None:
|
1644
|
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
|
1645
|
+
scores = scores.masked_fill(mask, -torch.inf)
|
1646
|
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
1647
|
+
mask, 0.0
|
1648
|
+
) # (batch, head, time1, time2)
|
1649
|
+
else:
|
1650
|
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
1651
|
+
return attn
|
1652
|
+
|
1653
|
+
|
1654
|
+
class MultiHeadedAttention(nn.Module):
|
1655
|
+
"""Multi-Head Attention layer with optional relative position embedding
|
1656
|
+
and GLU.
|
1657
|
+
|
1658
|
+
Args:
|
1659
|
+
n_head: int
|
1660
|
+
the number of heads.
|
1661
|
+
n_feat: int
|
1662
|
+
input size features.
|
1663
|
+
dropout_rate: float
|
1664
|
+
dropout rate.
|
1665
|
+
use_LN: bool
|
1666
|
+
apply layer norm or not
|
1667
|
+
dropout_at_output: bool
|
1668
|
+
whether to apply dropout at output
|
1669
|
+
attention_inner_dim: int, optional
|
1670
|
+
the attention dimension used in the class,
|
1671
|
+
it can be different from the input dimension n_feat.
|
1672
|
+
default: -1 (equal to n_feat).
|
1673
|
+
use_pt_scaled_dot_product_attention: bool, optional
|
1674
|
+
if set True, use pytorch scaled dot product attention in training.
|
1675
|
+
NOTE: this will NOT be used in ONNX decoding due to a lack of
|
1676
|
+
support. In that case, we use the original attention
|
1677
|
+
implementation, which shows no regression.
|
1678
|
+
default: False.
|
1679
|
+
n_value: int, optional
|
1680
|
+
if set to values other than -1, use a different dimension for
|
1681
|
+
value. With the default value (i.e. -1), it is backward compatible.
|
1682
|
+
group_size: int, optional. must divide `n_head`
|
1683
|
+
if group_size > 1: GQA
|
1684
|
+
if group_size = 1: MHA
|
1685
|
+
if group_size = n_head: MQA
|
1686
|
+
"""
|
1687
|
+
|
1688
|
+
inv_sqrt_d_k: torch.jit.Final[float]
|
1689
|
+
h: torch.jit.Final[int]
|
1690
|
+
h_k: torch.jit.Final[int]
|
1691
|
+
g: torch.jit.Final[int]
|
1692
|
+
|
1693
|
+
def __init__(
|
1694
|
+
self,
|
1695
|
+
n_head,
|
1696
|
+
n_feat,
|
1697
|
+
dropout_rate,
|
1698
|
+
attention_inner_dim=-1,
|
1699
|
+
glu_type="swish",
|
1700
|
+
bias_in_glu=True,
|
1701
|
+
use_pt_scaled_dot_product_attention=False,
|
1702
|
+
n_value=-1,
|
1703
|
+
group_size: int = 1,
|
1704
|
+
):
|
1705
|
+
super().__init__()
|
1706
|
+
if n_value == -1:
|
1707
|
+
n_value = n_feat
|
1708
|
+
if attention_inner_dim == -1:
|
1709
|
+
attention_inner_dim = n_feat
|
1710
|
+
assert attention_inner_dim % n_head == 0
|
1711
|
+
|
1712
|
+
# We assume d_v always equals d_k
|
1713
|
+
self.d_k = attention_inner_dim // n_head
|
1714
|
+
self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
|
1715
|
+
self.h = n_head
|
1716
|
+
assert n_head % group_size == 0, "group_size must divide n_head"
|
1717
|
+
self.g = group_size
|
1718
|
+
self.h_k = n_head // group_size
|
1719
|
+
|
1720
|
+
self.linear_q = nn.Linear(n_feat, attention_inner_dim)
|
1721
|
+
self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
|
1722
|
+
self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
|
1723
|
+
self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)
|
1724
|
+
|
1725
|
+
self.attn = torch.jit.Attribute(None, Optional[Tensor])
|
1726
|
+
self.dropout = nn.Dropout(p=dropout_rate)
|
1727
|
+
self.dropout_rate = dropout_rate
|
1728
|
+
self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention
|
1729
|
+
|
1730
|
+
if use_pt_scaled_dot_product_attention and group_size > 1:
|
1731
|
+
raise ValueError("Cannot use PT Scaled Attention with GQA")
|
1732
|
+
|
1733
|
+
# Torchscript eager quantization. Note that these functions below are
|
1734
|
+
# NOOPs and have very little impact on performance unless quantization
|
1735
|
+
# is enabled.
|
1736
|
+
self.quant_q = torch.ao.quantization.QuantStub()
|
1737
|
+
self.quant_x = torch.ao.quantization.QuantStub()
|
1738
|
+
self.dequant = torch.ao.quantization.DeQuantStub()
|
1739
|
+
self.ffunc = torch.ao.nn.quantized.FloatFunctional()
|
1740
|
+
|
1741
|
+
def forward(
|
1742
|
+
self,
|
1743
|
+
query: Tensor,
|
1744
|
+
key: Tensor,
|
1745
|
+
value: Tensor,
|
1746
|
+
pos_k: Tensor,
|
1747
|
+
pos_v: Tensor,
|
1748
|
+
mask: Optional[Tensor],
|
1749
|
+
relative_attention_bias: Optional[Tensor] = None,
|
1750
|
+
):
|
1751
|
+
"""Compute 'Scaled Dot Product Attention'.
|
1752
|
+
|
1753
|
+
Args:
|
1754
|
+
query: torch.Tensor
|
1755
|
+
query tensor (batch, time1, size)
|
1756
|
+
key: torch.Tensor
|
1757
|
+
key tensor (batch, time2, size)
|
1758
|
+
value: torch.Tensor
|
1759
|
+
value tensor (batch, time1, size)
|
1760
|
+
pos_k: torch.Tensor
|
1761
|
+
key tensor used for relative positional embedding.
|
1762
|
+
pos_v: torch.Tensor
|
1763
|
+
value tensor used for relative positional embedding.
|
1764
|
+
mask: torch.Tensor
|
1765
|
+
mask tensor (batch, time1, time2)
|
1766
|
+
relative_attention_bias: torch.Tensor
|
1767
|
+
bias added to attention logits w.r.t. relative positions
|
1768
|
+
(1, n_head, time1, time2)
|
1769
|
+
"""
|
1770
|
+
n_batch = query.size(0)
|
1771
|
+
|
1772
|
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d)
|
1773
|
+
k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d)
|
1774
|
+
v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
|
1775
|
+
q = (
|
1776
|
+
q.transpose(1, 2)
|
1777
|
+
if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
|
1778
|
+
else q.transpose(1, 2) * self.inv_sqrt_d_k
|
1779
|
+
)
|
1780
|
+
k = k.transpose(1, 2) # (batch, head_k, time2, d_k)
|
1781
|
+
v = v.transpose(1, 2) # (batch, head_k, time2, d_k)
|
1782
|
+
|
1783
|
+
if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
|
1784
|
+
attn_mask = None
|
1785
|
+
if mask is not None:
|
1786
|
+
mask = mask.unsqueeze(1)
|
1787
|
+
if relative_attention_bias is not None:
|
1788
|
+
attn_mask = mask + relative_attention_bias
|
1789
|
+
else:
|
1790
|
+
attn_mask = mask
|
1791
|
+
if mask.dtype != q.dtype:
|
1792
|
+
attn_mask = attn_mask.to(q.dtype)
|
1793
|
+
|
1794
|
+
with torch.nn.attention.sdpa_kernel(
|
1795
|
+
[
|
1796
|
+
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
|
1797
|
+
torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
|
1798
|
+
torch.nn.attention.SDPBackend.MATH,
|
1799
|
+
torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
|
1800
|
+
]
|
1801
|
+
):
|
1802
|
+
x = torch.nn.functional.scaled_dot_product_attention(
|
1803
|
+
q,
|
1804
|
+
k,
|
1805
|
+
v,
|
1806
|
+
attn_mask=attn_mask,
|
1807
|
+
dropout_p=self.dropout_rate,
|
1808
|
+
)
|
1809
|
+
else:
|
1810
|
+
if self.h != self.h_k:
|
1811
|
+
q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
|
1812
|
+
A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
|
1813
|
+
else:
|
1814
|
+
A = torch.matmul(q, k.transpose(-2, -1))
|
1815
|
+
if pos_k is not None:
|
1816
|
+
if self.h != self.h_k:
|
1817
|
+
B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
|
1818
|
+
else:
|
1819
|
+
reshape_q = (
|
1820
|
+
q.contiguous()
|
1821
|
+
.view(n_batch * self.h, -1, self.d_k)
|
1822
|
+
.transpose(0, 1)
|
1823
|
+
) # (t1,nh,dk)
|
1824
|
+
B = torch.matmul(
|
1825
|
+
reshape_q, pos_k.transpose(-2, -1)
|
1826
|
+
) # pos_k: (t1,dk,t2)
|
1827
|
+
B = B.transpose(0, 1).view(
|
1828
|
+
n_batch, self.h, pos_k.size(0), pos_k.size(1)
|
1829
|
+
)
|
1830
|
+
scores = A + B
|
1831
|
+
else:
|
1832
|
+
scores = A
|
1833
|
+
|
1834
|
+
if relative_attention_bias is not None:
|
1835
|
+
scores = scores + relative_attention_bias
|
1836
|
+
|
1837
|
+
attn = masked_softmax(scores, mask) # (batch, head, time1, time2)
|
1838
|
+
|
1839
|
+
self.attn = attn
|
1840
|
+
|
1841
|
+
p_attn = self.dropout(attn)
|
1842
|
+
x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k)
|
1843
|
+
if pos_v is not None:
|
1844
|
+
reshape_attn = (
|
1845
|
+
p_attn.contiguous()
|
1846
|
+
.view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
|
1847
|
+
.transpose(0, 1)
|
1848
|
+
) # (t1, bh, t2)
|
1849
|
+
|
1850
|
+
attn_v = (
|
1851
|
+
torch.matmul(reshape_attn, pos_v)
|
1852
|
+
.transpose(0, 1)
|
1853
|
+
.contiguous()
|
1854
|
+
.view(n_batch, self.h, pos_v.size(0), self.d_k)
|
1855
|
+
)
|
1856
|
+
x = x + attn_v
|
1857
|
+
x = (
|
1858
|
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
|
1859
|
+
) # (batch, time1, d_model)
|
1860
|
+
|
1861
|
+
return self.linear_out(x) # (batch, time1, d_model)
|
1862
|
+
|
1863
|
+
|
1864
|
+
class MultiSequential(torch.nn.Sequential):
|
1865
|
+
"""Multi-input multi-output torch.nn.Sequential"""
|
1866
|
+
|
1867
|
+
@torch.jit.ignore
|
1868
|
+
def forward(self, *args):
|
1869
|
+
"""Forward method implementation."""
|
1870
|
+
for m in self:
|
1871
|
+
args = m(*args)
|
1872
|
+
return args
|
1873
|
+
|
1874
|
+
|
1875
|
+
def get_offset(input_layer: str, time_reduction: int):
|
1876
|
+
"""Get an offset. We will use the offset for determining #frames of a
|
1877
|
+
subsampled feature.
|
1878
|
+
|
1879
|
+
Args:
|
1880
|
+
input_layer (str): Type of an input layer
|
1881
|
+
time_reduction (int): time reduction factor for downsampling a feature
|
1882
|
+
Returns:
|
1883
|
+
int: offset
|
1884
|
+
"""
|
1885
|
+
if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
|
1886
|
+
return 3
|
1887
|
+
if input_layer in ("conv2d",) and time_reduction == 6:
|
1888
|
+
return 1
|
1889
|
+
if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
|
1890
|
+
return 7
|
1891
|
+
return 0
|
1892
|
+
|
1893
|
+
|
1894
|
+
def unfold_tensor(xs_pad, max_seq_len):
|
1895
|
+
"""
|
1896
|
+
For a given tensor with shape of (N, T, D), if sequence length T is
|
1897
|
+
longer than max_seq_len, this function unfold it to a
|
1898
|
+
(NT', max_seq_len, D) where T' is T // max_seq_len.
|
1899
|
+
Args:
|
1900
|
+
xs_pad: N, T, D
|
1901
|
+
"""
|
1902
|
+
_, _, D = xs_pad.shape
|
1903
|
+
xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
|
1904
|
+
# N x D x 1 x T => N x (D x max_seq_len) x T'
|
1905
|
+
xs_pad = F.unfold(
|
1906
|
+
xs_pad[..., None, :],
|
1907
|
+
kernel_size=(1, max_seq_len),
|
1908
|
+
stride=(1, max_seq_len),
|
1909
|
+
)
|
1910
|
+
new_bsz, _, slen = xs_pad.shape
|
1911
|
+
# N x D x max_seq_len x T'
|
1912
|
+
xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
|
1913
|
+
# N x T' x max_seq_len x D
|
1914
|
+
xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
|
1915
|
+
# NT' x max_seq_len x D
|
1916
|
+
xs_pad = xs_pad.view(-1, max_seq_len, D)
|
1917
|
+
return xs_pad
|