sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1260 @@
|
|
1
|
+
# Copyright 2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
#!/usr/bin/env python3
|
15
|
+
import abc
|
16
|
+
import math
|
17
|
+
from typing import Literal, Optional
|
18
|
+
|
19
|
+
import numpy as np
|
20
|
+
import torch
|
21
|
+
import torch.nn.functional as F
|
22
|
+
from torch import Tensor, nn
|
23
|
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
24
|
+
CheckpointWrapper,
|
25
|
+
)
|
26
|
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
|
27
|
+
from transformers import PretrainedConfig
|
28
|
+
|
29
|
+
from sglang.srt.models.phi4mm_utils import (
|
30
|
+
AbsolutePositionalEncoding,
|
31
|
+
ConvModule,
|
32
|
+
FeedForward,
|
33
|
+
MeanVarianceNormLayer,
|
34
|
+
MultiHeadedAttention,
|
35
|
+
MultiSequential,
|
36
|
+
NemoConvSubsampling,
|
37
|
+
T5RelativeAttentionLogitBias,
|
38
|
+
adaptive_enc_mask,
|
39
|
+
get_offset,
|
40
|
+
unfold_tensor,
|
41
|
+
)
|
42
|
+
|
43
|
+
# Vocabulary id of the special token used as the audio placeholder
# (``<|endoftext11|>`` per the inline note). Presumably used elsewhere to
# locate audio spans in the prompt — TODO confirm against callers.
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
|
44
|
+
|
45
|
+
|
46
|
+
class ConformerEncoderLayer(nn.Module):
    """One Conformer block (see https://arxiv.org/abs/2005.08100).

    This module implements the Conformer block layer. The forward pass
    applies, in order: a half-weighted feed-forward module, multi-head
    self-attention over a LayerNorm-ed input, a convolution module, a second
    half-weighted feed-forward module, and a final LayerNorm. Each sub-module
    output is added back onto the running hidden state.

    Args:
        d_model: int
            attention dim.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
            for the last pointwise conv after swish activation.
        depthwise_seperable_out_channel: int
            if non-zero, used as the channel_out of the second conv1d
            layer; if 0, the second conv1d layer is skipped.
        depthwise_multiplier: int
            number of input_dim channels duplication; used to compute the
            hidden channels of the Conv1D.
        n_head: int
            the number of heads for the multihead attention module.
        d_ffn: int
            inner size of both feed-forward blocks (passed as ``d_inner``).
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if True, the convolution has no access to future frames.
            default False.
        batch_norm: bool, optional
            if True, apply batchnorm before activation in the ConvModule
            layer of the conformer. default False.
        activation: str, optional
            activation function name for the feed-forward blocks,
            one of ["relu", "swish", "sigmoid"]; sigmoid is only used with
            "glu_in_fnn=True". default "relu".
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where the mean is computed over the
            accumulated history up to the current chunk_se.
            2 for streaming SE, where the mean is computed over only the
            current chunk.
            default 0.
        chunk_size: int, optional
            chunk_size for cnn. default 18.
        conv_activation: str, optional
            activation function used in the ConvModule part of the
            conformer. default "relu".
        conv_glu_type: str, optional
            activation function used for the GLU inside the ConvModule
            part of the conformer. default "sigmoid".
        bias_in_glu: bool, optional
            if True, use an additive bias in the weight module before GLU.
        linear_glu_in_convm: bool, optional
            if True, use the GLULinear module, otherwise the
            GLUPointWiseConv module. default False.
        attention_inner_dim: int, optional
            if -1, the attention dim for the k/q/v linears equals d_model;
            otherwise attention_inner_dim is used. default -1.
        attention_glu_type: str, optional
            activation function for the GLU used in the multihead attention.
            default "swish".
        activation_checkpointing: str, optional
            accepted for configuration compatibility; it is not referenced
            anywhere in this class. (Documented upstream as a dictionary with
            keys "module", "interval", "offload".) default "".
        export: bool, optional
            if True, remove the padding from convolutional layers to allow
            ONNX conversion for inference. default False.
        use_pt_scaled_dot_product_attention: bool, optional
            if True, use PyTorch's scaled dot product attention
            implementation in training.
        attn_group_sizes: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention):
            1 = typical Multi-Head Attention,
            1 < attn_group_sizes < attention_heads = Grouped-Query Attention,
            attn_group_sizes = attention_heads = Multi-Query Attention.
    """

    def __init__(
        self,
        d_model=512,
        ext_pw_out_channel=0,
        depthwise_seperable_out_channel=256,
        depthwise_multiplier=1,
        n_head=4,
        d_ffn=2048,
        ext_pw_kernel_size=1,
        kernel_size=3,
        dropout_rate=0.1,
        causal=False,
        batch_norm=False,
        activation="relu",
        chunk_se=0,
        chunk_size=18,
        conv_activation="relu",
        conv_glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        attention_inner_dim=-1,
        attention_glu_type="swish",
        activation_checkpointing="",
        export=False,
        use_pt_scaled_dot_product_attention=False,
        attn_group_sizes: int = 1,
    ):
        super().__init__()

        # Both feed-forward modules are built from the same configuration.
        ffn_config = dict(
            d_model=d_model,
            d_inner=d_ffn,
            dropout_rate=dropout_rate,
            activation=activation,
            bias_in_glu=bias_in_glu,
        )
        # Positional arguments for ConvModule, in the order it expects.
        conv_positional = (
            d_model,
            ext_pw_out_channel,
            depthwise_seperable_out_channel,
            ext_pw_kernel_size,
            kernel_size,
            depthwise_multiplier,
            dropout_rate,
            causal,
            batch_norm,
            chunk_se,
            chunk_size,
            conv_activation,
            conv_glu_type,
            bias_in_glu,
            linear_glu_in_convm,
        )

        # Sub-modules are registered in the same order as upstream so that
        # parameter ordering and checkpoint keys are unchanged.
        self.feed_forward_in = FeedForward(**ffn_config)
        self.self_attn = MultiHeadedAttention(
            n_head,
            d_model,
            dropout_rate,
            attention_inner_dim,
            attention_glu_type,
            bias_in_glu,
            use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
            group_size=attn_group_sizes,
        )
        self.conv = ConvModule(*conv_positional, export=export)
        self.feed_forward_out = FeedForward(**ffn_config)

        self.layer_norm_att = nn.LayerNorm(d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x,
        pos_k,
        pos_v,
        mask,
        relative_attention_bias: Optional[Tensor] = None,
    ):
        """Run one ConformerEncoderLayer.

        Args:
            x: torch.Tensor
                input feature of shape (batch, max_time_in, size)
            pos_k: torch.Tensor
                positional key embedding, forwarded to self-attention.
            pos_v: torch.Tensor
                positional value embedding, forwarded to self-attention.
            mask: torch.Tensor
                mask for x (batch, max_time_in)
            relative_attention_bias: Optional[torch.Tensor]
                bias added to attention logits w.r.t. relative positions
                (1, n_head, time1, time2)

        Returns:
            tuple: ``(out, pos_k, pos_v, mask)``, where ``out`` is the
            LayerNorm-ed block output and the other three inputs are
            returned unchanged.
        """
        # First feed-forward module, contributing with weight 0.5.
        hidden = x + 0.5 * self.feed_forward_in(x)

        # Self-attention: queries, keys and values all come from the same
        # normalised tensor.
        attn_in = self.layer_norm_att(hidden)
        attn_out = self.self_attn(
            attn_in,
            attn_in,
            attn_in,
            pos_k,
            pos_v,
            mask,
            relative_attention_bias=relative_attention_bias,
        )
        hidden = hidden + attn_out

        hidden = hidden + self.conv(hidden)

        # Second feed-forward module, also weighted by 0.5.
        hidden = hidden + 0.5 * self.feed_forward_out(hidden)

        return self.layer_norm(hidden), pos_k, pos_v, mask
|
264
|
+
|
265
|
+
|
266
|
+
class TransformerEncoderBase(abc.ABC, nn.Module):
    """The base class for Transformer-based encoders.

    For streaming use, the NeMo subsampling front-end should be causal
    (``nemo_conv_settings["is_causal"]``).

    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk.
            This variable can take 2 forms:
            int: used for inference, or single chunk size training
            list(int): used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int: used for inference, or single chunk size training
            list(int): used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with the same
            length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        input_layer: str, optional
            input layer type before the Conformer. Only "nemo_conv" is
            supported by this class; any other value raises ValueError.
            default "nemo_conv"
        cnn_out: int, optional
            the number of CNN channels before the Conformer.
            Not used by this base class. default -1.
        cnn_layer_norm: bool, optional
            layer norm between the Conformer and the first CNN.
            Not used by this base class. default False.
        time_reduction: int, optional
            time reduction (subsampling) factor. default 4
        dropout_rate: float, optional
            dropout rate. Not used by this base class. default 0.0
        padding_idx: int, optional
            padding index for input_layer=embed.
            Not used by this base class. default -1
        relative_attention_bias_args: dict, optional
            scalar bias-based relative multihead attention (Q*K^T + B).
            Only ``{"type": "t5"}`` is supported; the optional keys
            "t5_bias_max_distance" (default 1000) and "t5_bias_symmetric"
            (default False) are read. Any other type, including None,
            raises NotImplementedError.
        positional_dropout_rate: float, optional
            dropout rate after positional encoding. default 0.0
        nemo_conv_settings: dict, optional
            overrides for the NeMo subsampling defaults. Must not contain
            "subsampling_factor", "feat_in" or "feat_out"; those are
            derived from time_reduction, input_size and attention_dim.
            default None
        conv2d_extra_padding: str, optional
            one of (feat, feat_time, none, True).
            Not used by this base class. Default: "none"
        attention_group_size: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention):
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query
            Attention,
            attention_group_size = attention_heads = Multi-Query Attention.
            Must divide attention_heads.
        encoder_embedding_config: mapping
            must provide "input_size", the feature size passed to
            MeanVarianceNormLayer.
    """

    def __init__(
        self,
        input_size,
        chunk_size,
        left_chunk,
        attention_dim=256,
        attention_heads=4,
        input_layer="nemo_conv",
        cnn_out=-1,
        cnn_layer_norm=False,
        time_reduction=4,
        dropout_rate=0.0,
        padding_idx=-1,
        relative_attention_bias_args=None,
        positional_dropout_rate=0.0,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.input_layer = input_layer
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.attention_dim = attention_dim
        self.num_heads = attention_heads
        self.attention_group_size = attention_group_size
        self.time_reduction = time_reduction
        self.nemo_conv_settings = nemo_conv_settings
        self.encoder_embedding_config = encoder_embedding_config

        if self.input_layer == "nemo_conv":
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.time_reduction,
                "feat_in": input_size,
                "feat_out": attention_dim,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings.
            if nemo_conv_settings:
                # These three are derived from constructor arguments and must
                # not be overridden through the NeMo dictionary.
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert (
                        i not in nemo_conv_settings
                    ), f"{i} should be specified outside of the NeMo dictionary"
                default_nemo_conv_settings.update(nemo_conv_settings)

            self.embed = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            raise ValueError(f"unknown input_layer: {input_layer}")

        self.pos_emb = AbsolutePositionalEncoding(
            attention_dim, positional_dropout_rate
        )

        self.relative_attention_bias_type = (
            relative_attention_bias_args.get("type")
            if relative_attention_bias_args
            else None
        )
        if self.relative_attention_bias_type == "t5":
            assert (
                self.num_heads % self.attention_group_size == 0
            ), "attention_group_size must divide n_head"
            self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
                self.num_heads // self.attention_group_size,
                max_distance=relative_attention_bias_args.get(
                    "t5_bias_max_distance", 1000
                ),
                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
            )
        else:
            raise NotImplementedError(
                "unsupported relative_attention_bias type: "
                f"{self.relative_attention_bias_type!r} (only 't5' is supported)"
            )

        self.encoder_embedding = MeanVarianceNormLayer(
            self.encoder_embedding_config["input_size"]
        )

    def compute_lens_change(self, feature_lens):
        """Return the sequence length(s) after time reduction.

        Args:
            feature_lens: int or torch.Tensor
                input length(s) before subsampling.

        Returns:
            The reduced length(s): an int for scalar input, a tensor for
            tensor input.

        This used to return a different lambda function for each case that
        computed the right thing. That does not work within Torchscript.
        If you really need this to be faster, create nn.Module()-s for all
        the cases and return one of them. Torchscript does support that.
        """
        if self.input_layer == "nemo_conv":
            # nemo_conv_settings defaults to None; treat that as "no
            # overrides" instead of failing with AttributeError on ``.get``.
            nemo_settings = self.nemo_conv_settings or {}
            # Handle the special causal case.
            subsampling_causal_cond = nemo_settings.get(
                "subsampling", "dw_striding"
            ) in [
                "dw_striding",
                "striding",
                "striding_conv1d",
            ]
            is_causal = nemo_settings.get("is_causal", False)
            if is_causal and subsampling_causal_cond:
                lens_change = (
                    torch.ceil(feature_lens / self.time_reduction).long()
                    if isinstance(feature_lens, Tensor)
                    else math.ceil(feature_lens / self.time_reduction)
                )
                feature_lens_remainder = feature_lens % self.time_reduction
                if isinstance(feature_lens, Tensor):
                    lens_change[feature_lens_remainder != 1] += 1
                elif feature_lens_remainder != 1:
                    lens_change += 1
                return lens_change
        # Dispatch on Tensor (as the causal branch does) so any non-tensor
        # scalar, not only a Python int, goes through math.ceil.
        ceil_func = torch.ceil if isinstance(feature_lens, Tensor) else math.ceil
        return ceil_func(feature_lens / self.time_reduction)

    @abc.abstractmethod
    def forward(self):
        """Abstract forward method implementation."""

    def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
        """Return the effective ``(chunk_size, left_chunk)`` pair.

        Missing arguments fall back to the values given at construction.
        If chunk_size is a list, one index is drawn at random and used for
        both chunk_size and left_chunk, which must then be a list of the
        same length.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        if left_chunk is None:
            left_chunk = self.left_chunk
        if isinstance(chunk_size, list):
            # Variable chunk size during training.
            chunk_size_index = int(
                torch.randint(low=0, high=len(chunk_size), size=(1,))
            )
            chunk_size_train_eff = chunk_size[chunk_size_index]
            if not isinstance(left_chunk, list):
                raise ValueError(
                    "Since chunk_size is a list, left_chunk must be a list"
                )
            if len(left_chunk) != len(chunk_size):
                raise ValueError(
                    "The length of left_chunk must be the same as length of "
                    "chunk_size."
                )
            left_chunk_train_eff = left_chunk[chunk_size_index]
        else:
            chunk_size_train_eff = chunk_size
            left_chunk_train_eff = left_chunk

        return chunk_size_train_eff, left_chunk_train_eff

    def _get_embed_class(self, embed):
        """Unwrap activation-checkpoint / FSDP wrappers around ``embed``."""
        # pylint: disable=protected-access
        is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
        is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
        embed_class = embed
        if is_embed_using_act_chkpt:
            embed_class = embed._checkpoint_wrapped_module
        if is_embed_fsdp_wrapped:
            embed_class = embed.module
        return embed_class

    def _forward_embeddings_core(self, input_tensor, masks):
        """Run the NeMo subsampling front-end; returns ``(tensor, masks)``."""
        embed_class = self._get_embed_class(self.embed)
        assert isinstance(embed_class, NemoConvSubsampling)
        input_tensor, masks = self.embed(input_tensor, masks)
        return input_tensor, masks

    def _position_embedding(self, input_tensor):
        """Return ``(pos_k, pos_v)``; both are always None here.

        NOTE(review): when no relative-attention-bias layer exists, the
        absolute sinusoid encoding is computed into a local and then
        discarded, because only ``(pos_k, pos_v)`` is returned. Once
        ``__init__`` succeeds, ``relative_attention_bias_layer`` is always
        set (any other bias type raises), so this branch does not run in
        the code visible here.
        """
        pos_k = None
        pos_v = None
        if self.relative_attention_bias_layer is None:
            input_tensor = self.pos_emb(
                input_tensor
            )  # default to add abs sinusoid embedding
        return pos_k, pos_v

    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        """Build a chunked streaming attention mask expanded over the batch."""
        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
            chunk_size, left_chunk
        )

        # Create the mask matrix for streaming.
        # chunk_start_idx stores start indices: for chunk size 18 it is
        # [0, 18, 36, ...].
        chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)

        enc_streaming_mask = (
            adaptive_enc_mask(
                seq_len, chunk_start_idx, left_window=left_chunk_train_eff
            )
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None):
        """Forward the inputs through the top embedding layers.

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input mask
            chunk_size_nc: (optional, default is None) chunk size for
                non-causal layers
            left_chunk_nc: (optional, default is None) # of left chunks for
                non-causal layers

        Returns:
            ``(input_tensor, pos_k, pos_v, hs_mask, masks)``, plus a trailing
            ``hs_mask_nc`` when ``chunk_size_nc`` is given.

        Raises:
            ValueError: if the reduced sequence length is not positive.
        """
        # pylint: disable=R0915
        # get new lens.
        seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
        if seq_len <= 0:
            raise ValueError(
                f"""The sequence length after time reduction is invalid:
                {seq_len}. Your input feature is too short. Consider
                filtering out the very short sentence from data
                loader""",
            )

        batch_size = xs_pad.shape[0]

        # Place the mask on xs_pad's own device. ``.cuda()`` would move it
        # to the *current* CUDA device, which can differ from xs_pad's
        # device in multi-GPU setups and break ``masks & streaming_mask``.
        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.chunk_size, self.left_chunk
        ).to(xs_pad.device)

        input_tensor = xs_pad
        input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        if chunk_size_nc is not None:
            enc_streaming_mask_nc = self._streaming_mask(
                seq_len, batch_size, chunk_size_nc, left_chunk_nc
            ).to(xs_pad.device)
            if masks is not None:
                hs_mask_nc = masks & enc_streaming_mask_nc
            else:
                hs_mask_nc = enc_streaming_mask_nc
        else:
            hs_mask_nc = None

        pos_k, pos_v = self._position_embedding(input_tensor)

        if chunk_size_nc is None:
            return input_tensor, pos_k, pos_v, hs_mask, masks
        return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

    def get_offset(self):
        """Returns offset used when retaining inputs for decoding.

        This is essentially, how many additional frames have to be added to
        the front-end CNN input to ensure it can produce a single output.
        So if the "padding" parameter is 0, typically offset will be > 0.
        """
        # ``get_offset`` here resolves to the module-level helper, not this
        # method (method names are not in scope inside their own body).
        return get_offset(self.input_layer, self.time_reduction)
|
606
|
+
|
607
|
+
|
608
|
+
class ConformerEncoder(TransformerEncoderBase):
    """ConformerEncoder module.

    See the original paper for more details:
    https://arxiv.org/abs/2005.08100

    Please set causal = True in a streaming model.

    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk.
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int): Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int: Used for inference, or single chunk size training
            list(int): Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        num_lang: int
            This parameter is used to store the number of languages in the
            lang_dict, only used for multiseed/multilingual models.
            default None.
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        linear_units: int, optional
            the number of units of position-wise feed forward.
            default 2048
        num_blocks: int, optional
            number of Conformer layers. default 6
        dropout_rate: float, optional
            dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer; forwarded to
            TransformerEncoderBase. default "nemo_conv"
        causal: bool, optional
            if set to True, convolution have no access
            to future frames. default True.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        ext_pw_out_channel: int, optional
            the number of channel for CNN
            before depthwise_seperable_CNN.
            If 0 then use linear. default 0.
        ext_pw_kernel_size: int, optional
            kernel size of N before depthwise_seperable_CNN.
            only work for ext_pw_out_channel > 0.
            default 1
        depthwise_seperable_out_channel: int, optional
            the number of channel for
            depthwise_seperable_CNN.
            default 256.
        depthwise_multiplier: int, optional
            the number of multiplier for
            depthwise_seperable_CNN.
            default 1.
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
            by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
            by only the current chunk.
            default 0.
        kernel_size: int, optional
            the number of kernels for depthwise_seperable_CNN.
            default 3.
        activation: str, optional
            FeedForward block activation.
            one of ["relu", "swish", "sigmoid"]
            default "relu".
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation used use glu in depthwise_seperable_CNN,
            default "sigmoid"
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU. default True
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
            otherwise, used GLUPointWiseConv module.
            default to False.
        attention_glu_type: str
            only work for glu_in_attention != 0
            default "swish".
        export: bool, optional
            if set to True, it remove the padding from convolutional layers
            and allow the onnx conversion for inference.
            default False.
        extra_layer_output_idx: int
            the layer index to be exposed. default -1. Any value other
            than -1 makes forward() run the layers one by one instead of
            through the MultiSequential container.
        extra_multi_layer_output_idxs: list(int), optional
            layer indices stored on the instance. default None, which is
            stored as an empty list. The given list is copied.
        activation_checkpointing: str, optional
            a dictionary of {"module","interval","offload"}, where
            "module": str
                accept ["transformer", "attention"] to select
                which module should do activation checkpointing.
            "interval": int, default 1,
                interval of applying activation checkpointing,
                interval = 1 means that we apply checkpointing
                on every layer (if activation), otherwise,
                we apply it every x interval.
            "offload": bool, default False,
                if set to True, we offload activation to cpu and
                reload it during backward, otherwise,
                we recalculate activation in backward.
            default "".
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention
            (Q*K^T + B) implemented in cmb.basics.embedding.
            [T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see
            transformer_base.py)
        time_reduction: int optional
            time reduction factor
            default 4
        use_pt_scaled_dot_product_attention: whether to use pytorch scaled
            dot product attention in training.
            Default: False
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default: None
            usage: nemo_conv_settings=
                {
                    "subsampling":
                    dw_striding/striding/dw_striding_conv1d/striding_conv1d,
                    "conv_channels": int,
                    "subsampling_conv_chunking_factor": int,
                    "is_causal": True/False
                }
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True)
            Default: none
        replication_pad_for_subsample_embedding: For batched-streaming
            decoding, use "replication" padding for the cache at start of
            utterance.
            Default: False
        attention_group_size: int, optional
            the number of groups to use for attention, default 1
            (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query
            Attention
            attention_group_size = attention_heads = Multi-Query Attention
        encoder_embedding_config: dict, optional
            forwarded unchanged to TransformerEncoderBase. default None.
    """

    extra_multi_layer_output_idxs: list[int]

    def __init__(  # pylint: disable-all
        self,
        input_size,
        chunk_size,
        left_chunk,
        num_lang=None,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        input_layer="nemo_conv",
        causal=True,
        batch_norm=False,
        cnn_out=-1,
        cnn_layer_norm=False,
        ext_pw_out_channel=0,
        ext_pw_kernel_size=1,
        depthwise_seperable_out_channel=256,
        depthwise_multiplier=1,
        chunk_se=0,
        kernel_size=3,
        activation="relu",
        conv_activation="relu",
        conv_glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        attention_glu_type="swish",
        export=False,
        extra_layer_output_idx=-1,
        extra_multi_layer_output_idxs=None,
        activation_checkpointing="",
        relative_attention_bias_args=None,
        time_reduction=4,
        use_pt_scaled_dot_product_attention=False,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        replication_pad_for_subsample_embedding=False,
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__(
            input_size,
            chunk_size,
            left_chunk,
            attention_dim,
            attention_heads,
            input_layer,
            cnn_out,
            cnn_layer_norm,
            time_reduction,
            dropout_rate=dropout_rate,
            relative_attention_bias_args=relative_attention_bias_args,
            positional_dropout_rate=0.0,
            nemo_conv_settings=nemo_conv_settings,
            conv2d_extra_padding=conv2d_extra_padding,
            attention_group_size=attention_group_size,
            encoder_embedding_config=encoder_embedding_config,
        )
        self.num_blocks = num_blocks
        self.num_lang = num_lang
        self.kernel_size = kernel_size
        self.replication_pad_for_subsample_embedding: bool = (
            replication_pad_for_subsample_embedding
        )
        assert (
            self.num_heads % attention_group_size == 0
        ), "attention_group_size must divide n_head"
        self.num_heads_k = self.num_heads // attention_group_size

        self.encoders = MultiSequential(
            *[
                ConformerEncoderLayer(
                    d_model=attention_dim,
                    ext_pw_out_channel=ext_pw_out_channel,
                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                    depthwise_multiplier=depthwise_multiplier,
                    n_head=attention_heads,
                    d_ffn=linear_units,
                    ext_pw_kernel_size=ext_pw_kernel_size,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    causal=causal,
                    batch_norm=batch_norm,
                    activation=activation,
                    chunk_se=chunk_se,
                    chunk_size=chunk_size,
                    conv_activation=conv_activation,
                    conv_glu_type=conv_glu_type,
                    bias_in_glu=bias_in_glu,
                    linear_glu_in_convm=linear_glu_in_convm,
                    attention_glu_type=attention_glu_type,
                    activation_checkpointing=activation_checkpointing,
                    export=export,
                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                    attn_group_sizes=attention_group_size,
                )
                for _ in range(num_blocks)
            ]
        )
        self.extra_layer_output_idx = extra_layer_output_idx
        # The default used to be a shared mutable ``[]``. Copy into a fresh
        # list so instances never share (or alias the caller's) list object.
        self.extra_multi_layer_output_idxs = (
            []
            if extra_multi_layer_output_idxs is None
            else list(extra_multi_layer_output_idxs)
        )
        # Make a zeros scalar we can use in get_initial_state to determine
        # the device and the needed dtype:
        self.register_buffer("dev_type", torch.zeros(()), persistent=False)

    def init_relative_attention_bias(self, input_tensor):
        """Return the relative attention bias for ``input_tensor``.

        Returns None when no ``relative_attention_bias_layer`` is configured.
        """
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)
        return None

    def calculate_hs_mask(self, xs_pad, device, mask):
        """Combine the streaming (chunk) mask with an optional padding mask.

        Args:
            xs_pad: tensor whose dim 0 is the batch size and dim 1 the
                sequence length used for the masks.
            device: device on which the masks are built.
            mask: optional tensor. Summing it over dim 1 gives the number of
                valid frames per batch item (presumably a (B, T) boolean
                padding mask; confirm against callers).

        Returns:
            The streaming mask alone when ``mask`` is None. Otherwise, the
            length-based padding mask (unsqueezed at dim 1) AND-ed with the
            streaming mask.
        """
        max_audio_length = xs_pad.shape[1]
        batch_size = xs_pad.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.chunk_size, self.left_chunk
        ).to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        positions = torch.arange(0, max_audio_length, device=device)
        pad_mask = positions.expand(feature_lens.size(0), -1) < feature_lens.unsqueeze(1)
        return pad_mask.unsqueeze(1) & enc_streaming_mask

    @torch.jit.ignore
    def forward(self, xs_pad, masks):
        """Conformer forward function.

        Sequences longer than ``max_seq_len`` (500) frames after the
        embedding are right-padded to a multiple of 500, and then split into
        500-frame chunks with ``unfold_tensor``. The chunks run through the
        encoder layers, and are reshaped back with the padding removed.

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                post-embedding input lengths

        Returns:
            Tuple ``(encoded, masks)``, where ``masks`` is the mask returned
            by ``forward_embeddings`` (not the unfolded one).
        """
        xs_pad = self.encoder_embedding(xs_pad)
        input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
            xs_pad, masks
        )

        unfolded = False
        chunk_pad_size = 0
        ori_bz, seq_len, _ = input_tensor.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # The audio sequence is longer than max_seq_len: unfold it into
            # chunks of max_seq_len.
            unfolded = True
            # unfold drops residual frames, so pad up to a multiple of
            # max_seq_len first.
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            if chunk_pad_size > 0:
                # F.pad keeps the input's device, so no extra .to() is needed.
                input_tensor = F.pad(
                    input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
                )
            input_tensor = unfold_tensor(input_tensor, max_seq_len)
            if masks is not None:
                # Revise hs_mask here, because the hs_mask computed above
                # did not account for the extra padding.
                subsampled_pad_mask = masks.squeeze(
                    1
                )  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsampled_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )  # extra padding to the pad mask
                # unfold op does not support bool tensors: go through float.
                extra_padded_subsampled_pad_mask = (
                    extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
                )
                # Unfold the pad mask the same way as the input tensor.
                masks_unfold = (
                    unfold_tensor(extra_padded_subsampled_pad_mask, max_seq_len)
                    .squeeze(-1)
                    .bool()
                )
            else:
                masks_unfold = None
            # Recompute hs_mask from the unfolded pad mask.
            hs_mask = self.calculate_hs_mask(
                input_tensor, input_tensor.device, masks_unfold
            )

        relative_attention_bias = self.init_relative_attention_bias(input_tensor)

        simplified_path = (
            self.extra_layer_output_idx == -1 and relative_attention_bias is None
        )
        if simplified_path:
            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
        else:
            for layer in self.encoders:
                input_tensor, _, _, _ = layer(
                    input_tensor,
                    pos_k,
                    pos_v,
                    hs_mask,
                    relative_attention_bias=relative_attention_bias,
                )

        if unfolded:
            embed_dim = input_tensor.shape[-1]
            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
            # If we padded before unfolding, remove that padding now.
            if chunk_pad_size > 0:
                input_tensor = input_tensor[:, :-chunk_pad_size, :]

        return input_tensor, masks
|
996
|
+
|
997
|
+
|
998
|
+
class WindowQformer(nn.Module):
    """Window-level Q-Former.

    The time axis of the input is cut into non-overlapping windows of
    ``window_size`` frames, right-padding with zeros when needed. A set of
    ``num_queries`` learned query vectors cross-attends to every window
    through a stack of ``nn.TransformerDecoderLayer`` blocks.
    """

    def __init__(
        self,
        window_size: int = 8,
        num_queries: int = 1,
        num_blocks: int = 2,
        attention_dim: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        dropout_rate: float = 0.0,
        normalize_before: bool = True,
    ):
        super().__init__()

        decoder_layers = []
        for _ in range(num_blocks):
            decoder_layers.append(
                nn.TransformerDecoderLayer(
                    d_model=attention_dim,
                    nhead=attention_heads,
                    dim_feedforward=linear_units,
                    dropout=dropout_rate,
                    activation="relu",
                    batch_first=True,
                    norm_first=normalize_before,  # TODO need to verify
                )
            )
        self.decoders = nn.ModuleList(decoder_layers)

        self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
        if normalize_before:
            self.after_norm = nn.LayerNorm(attention_dim, eps=1e-12)
        else:
            self.after_norm = None
        self.window_size = window_size

    def forward(self, audio_embed, mask, embed_len=None):
        """Run the learned queries over every window of ``audio_embed``.

        Args:
            audio_embed: input laid out as N x T x D.
            mask: passed unchanged as ``memory_mask`` to every decoder layer.
            embed_len: optional lengths; when given, they are floor-divided
                by ``window_size``.

        Returns:
            ``(out, embed_len)``. ``out`` is N x T' x (num_queries * D), with
            T' the number of windows after padding.
        """
        win = self.window_size

        # N x T x D -> N x D x T
        features = audio_embed.transpose(1, 2)
        # Right-pad the time axis so that it splits evenly into windows.
        remainder = features.shape[-1] % win
        if remainder > 0:
            features = F.pad(features, (0, win - remainder), "constant", 0)

        # N x D x 1 x T -> N x (D*K) x T'
        windows = F.unfold(
            features.unsqueeze(-2),
            kernel_size=(1, win),
            stride=(1, win),
        )
        batch, _, n_windows = windows.shape
        # N x D x K x T' -> N x T' x K x D -> (N*T') x K x D
        windows = (
            windows.view(batch, -1, win, n_windows)
            .transpose(1, 3)
            .contiguous()
            .view(batch * n_windows, win, -1)
        )

        # One copy of the learned queries per window: (N*T') x Q x D
        queries = self.queries.expand(batch * n_windows, -1, -1)
        for decoder in self.decoders:
            queries = decoder(
                tgt=queries, memory=windows, tgt_mask=None, memory_mask=mask
            )

        if self.after_norm is not None:
            queries = self.after_norm(queries)

        if embed_len is not None:
            embed_len = embed_len // win

        # (N*T') x Q x D -> N x T' x (Q*D)
        return queries.view(batch, n_windows, -1), embed_len
|
1073
|
+
|
1074
|
+
|
1075
|
+
class AudioEmbedding(nn.Module):
    """Audio embedding: encode audio features and project them to the LM width.

    The audio encoder is a ``ConformerEncoder`` built from
    ``config.audio_processor["config"]``; only the ``"cascades"`` processor
    is supported. The encoder output can then be compressed by a
    ``WindowQformer`` (``use_qformer``), by a ``NemoConvSubsampling`` layer
    (``use_conv_downsample``), or, with the ``"mlp"`` projection, by stacking
    ``downsample_rate`` consecutive frames. Finally, it is projected to the
    text model's hidden size.
    """

    def __init__(self, config: PretrainedConfig, **kwargs) -> None:
        super().__init__()
        self.config = config
        # n_embd or hidden_size for text LM
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.layer_idx = -2

        audio_processor = config.audio_processor
        if not (
            isinstance(audio_processor, dict)
            and audio_processor.get("name", None) == "cascades"
        ):
            raise NotImplementedError(
                f"Unsupported audio_processor {audio_processor!r}: "
                "only the 'cascades' processor is implemented"
            )
        encoder_config = audio_processor.get("config", None)
        assert encoder_config is not None
        self.encoder = ConformerEncoder(**encoder_config)

        audio_dim_out = encoder_config["attention_dim"]
        n_mels = encoder_config["input_size"]

        assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
        self.audio_dim_out = audio_dim_out
        self.audio_dim_in = n_mels

        self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)

        self.downsample_rate = kwargs.get("downsample_rate", 1)

        if kwargs.get("use_qformer", False):
            # Copy so that the caller's qformer_config dict is not mutated.
            qformer_config = dict(kwargs.get("qformer_config") or {})
            qformer_config["attention_dim"] = audio_dim_out
            self.qformer = WindowQformer(**qformer_config)
        else:
            self.qformer = None

        if kwargs.get("use_conv_downsample", False):
            assert (
                self.qformer is None
            ), "don't support use qformer and conv downsample together"
            nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.downsample_rate,
                "feat_in": audio_dim_out,
                "feat_out": audio_dim_out,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming user settings.
            # The sizes are derived from the encoder and downsample_rate and
            # must not be overridden here.
            if nemo_conv_settings:
                for key in ["subsampling_factor", "feat_in", "feat_out"]:
                    assert (
                        key not in nemo_conv_settings
                    ), f"{key} should be specified outside of the NeMo dictionary"
                default_nemo_conv_settings.update(nemo_conv_settings)

            self.conv_ds = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            self.conv_ds = None

        projection_cls = kwargs.get("projection_cls", "linear")
        if projection_cls == "linear":
            self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
            # The plain linear projection takes audio_dim_out features, so no
            # frame stacking is done before it. get_audio_features reads this
            # attribute unconditionally; without it, the linear path raised
            # AttributeError.
            self.linear_downsample_rate = 1
        elif projection_cls == "mlp":
            # follow llava-v1.5's implementation
            # (do not use image_projection and image_proj_norm)
            dim_projection = hidden_size
            depth = 2
            # The Q-Former / conv subsampler already reduce the frame rate;
            # otherwise, downsample by stacking frames before the MLP.
            self.linear_downsample_rate = (
                1
                if (self.qformer is not None or self.conv_ds is not None)
                else self.downsample_rate
            )
            in_dim = audio_dim_out * self.linear_downsample_rate
            self.audio_projection = self._mlp_projection(
                in_dim, dim_projection, depth
            )
            # NOTE vision-speech tasks use a separate projection layer
            self.audio_projection_for_vision = self._mlp_projection(
                in_dim, dim_projection, depth
            )
        else:
            raise NotImplementedError(
                f"projection_cls = {projection_cls}, not implemented"
            )

        # TODO: audio sequence compression - Qformer
        self.vocab_size = config.vocab_size
        self.input_embeds = None
        self.audio_embed_sizes = None

    @staticmethod
    def _mlp_projection(in_dim: int, out_dim: int, depth: int) -> nn.Sequential:
        """Build ``Linear(in, out)`` followed by ``depth - 1`` GELU+Linear(out, out) pairs."""
        layers = [nn.Linear(in_dim, out_dim)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(out_dim, out_dim)])
        return nn.Sequential(*layers)

    def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
        """Store ``input_embeds`` on the module."""
        self.input_embeds = input_embeds

    def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None:
        """Store ``audio_embed_sizes`` on the module."""
        self.audio_embed_sizes = audio_embed_sizes

    def get_audio_features(
        self,
        input_embeds: torch.FloatTensor,
        audio_attention_mask: torch.Tensor = None,
        audio_projection_mode: str = "speech",
    ) -> torch.FloatTensor:
        """Encode audio features and project them to the LM hidden size.

        Args:
            input_embeds: audio features (B, T, D), with B the number of
                audios in a sequence.
            audio_attention_mask: optional mask forwarded to the encoder.
            audio_projection_mode: "speech" uses ``audio_projection``.
                "vision" uses ``audio_projection_for_vision``, which only
                exists when ``projection_cls == "mlp"``.

        Returns:
            The projected audio features.

        Raises:
            ValueError: if ``audio_projection_mode`` is not recognised.
        """
        if self.freeze_audio_processor:
            with torch.no_grad():
                audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
        else:
            audio_features, masks = self.encoder(input_embeds, audio_attention_mask)

        if self.qformer is not None:
            audio_features, _ = self.qformer(audio_features, mask=None)

        if self.conv_ds is not None:
            if masks is not None:
                masks = masks.squeeze(1)

            audio_features, masks = self.conv_ds(audio_features, mask=masks)

        if self.linear_downsample_rate != 1:
            # Stack linear_downsample_rate consecutive frames into one,
            # zero-padding the time axis up to a multiple of the rate first.
            bs, seq_len, feat_dim = audio_features.size()
            padding = seq_len % self.linear_downsample_rate
            if padding > 0:
                audio_features = F.pad(
                    audio_features,
                    (0, 0, 0, self.linear_downsample_rate - padding),
                    "constant",
                    0,
                )

            seq_len = audio_features.size(1)
            audio_features = audio_features.view(
                bs,
                seq_len // self.linear_downsample_rate,
                feat_dim * self.linear_downsample_rate,
            )

        if audio_projection_mode == "speech":
            audio_set_tensor = self.audio_projection(audio_features)
        elif audio_projection_mode == "vision":
            audio_set_tensor = self.audio_projection_for_vision(audio_features)
        else:
            raise ValueError(
                f"audio_projection_mode = {audio_projection_mode} not " "implemented"
            )

        return audio_set_tensor

    def forward(
        self,
        audio_features: torch.FloatTensor,
        audio_attention_mask: torch.Tensor = None,
        audio_projection_mode: str = "speech",
    ) -> torch.FloatTensor:
        """
        arguments:
            audio_features: audio features (num_audio_tokens, T, D)

        returns:
            audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
        """
        audio_embeds = self.get_audio_features(
            audio_features,
            audio_attention_mask=audio_attention_mask,
            audio_projection_mode=audio_projection_mode,
        )
        return audio_embeds
|