sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/models/phi4mm_audio.py
@@ -0,0 +1,1260 @@
1
+ # Copyright 2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ #!/usr/bin/env python3
15
+ import abc
16
+ import math
17
+ from typing import Literal, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import Tensor, nn
23
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
24
+ CheckpointWrapper,
25
+ )
26
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
27
+ from transformers import PretrainedConfig
28
+
29
+ from sglang.srt.models.phi4mm_utils import (
30
+ AbsolutePositionalEncoding,
31
+ ConvModule,
32
+ FeedForward,
33
+ MeanVarianceNormLayer,
34
+ MultiHeadedAttention,
35
+ MultiSequential,
36
+ NemoConvSubsampling,
37
+ T5RelativeAttentionLogitBias,
38
+ adaptive_enc_mask,
39
+ get_offset,
40
+ unfold_tensor,
41
+ )
42
+
43
+ _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
44
+
45
+
46
+ class ConformerEncoderLayer(nn.Module):
47
+ """ConformerEncoder Layer module.
48
+ for more details see conformer paper:
49
+ https://arxiv.org/abs/2005.08100
50
+ This module implements the Conformer block layer.
51
+
52
+ Args:
53
+ d_model: int
54
+ attention dim.
55
+ ext_pw_out_channel: int
56
+ if > 0, ext_pw_out_channel is a dim channel size
57
+ for the last pointwise conv after swish activation.
58
+ depthwise_seperable_out_channel: int
59
+ if set to a value different from 0, the number of
60
+ depthwise_seperable_out_channel will be used as a
61
+ channel_out of the second conv1d layer.
62
+ otherwise (if it equals 0), the second conv1d layer is skipped.
63
+ depthwise_multiplier: int
64
+ number of times the input_dim channels are duplicated. This value
65
+ will be used to compute the hidden channels of the Conv1D.
66
+ n_head: int
67
+ the number of heads for multihead attention module.
68
+ d_ffn: int
69
+ output size of the feed_forward blocks.
70
+ ext_pw_kernel_size: int
71
+ kernel size of the conv pointwise of the conformer.
72
+ kernel_size: int
73
+ kernel size.
74
+ dropout_rate: float
75
+ dropout rate.
76
+ causal: bool, optional
77
+ if set to True, convolutions have no access
78
+ to future frames. default False.
79
+ batch_norm: bool, optional
80
+ if set to True, apply batchnorm before activation
81
+ in ConvModule layer of the conformer.
82
+ default False
83
+ activation: str, optional
84
+ activation function name,
85
+ one of ["relu", "swish", "sigmoid"],
86
+ sigmoid activation is only used with "glu_in_fnn=True",
87
+ default "relu".
88
+ chunk_se: int, optional
89
+ 0 for offline SE.
90
+ 1 for streaming SE, where mean is computed
91
+ by accumulated history until current chunk_se.
92
+ 2 for streaming SE, where mean is computed
93
+ by only the current chunk.
94
+ default 0.
95
+ chunk_size: int, optional
96
+ chunk_size for cnn. default 18
97
+ conv_activation: str, optional
98
+ activation function used in ConvModule part
99
+ of the conformer, default "relu".
100
+ conv_glu_type: str, optional
101
+ activation function used for the glu inside
102
+ the ConvModule part of the conformer.
103
+ default: "sigmoid".
104
+ bias_in_glu: bool, optional
105
+ if set to True, use additive bias in the weight module
106
+ before GLU.
107
+ linear_glu_in_convm: bool, optional
108
+ if set to True, use GLULinear module,
109
+ otherwise, use the GLUPointWiseConv module.
110
+ default to False.
111
+ attention_inner_dim: int, optional
112
+ if equal to -1, attention dim for linears k/q/v is
113
+ equal to d_model. otherwise attention_inner_dim is used.
114
+ default -1.
115
+ attention_glu_type: str, optional
116
+ activation function for glu used in the multihead attention,
117
+ default "swish".
118
+ activation_checkpointing: str, optional
119
+ a dictionary of {"module","interval","offload"}, where
120
+ "module": str
121
+ accept ["transformer", "attention"] to select
122
+ which module should do activation checkpointing.
123
+ "interval": int, default 1,
124
+ interval of applying activation checkpointing,
125
+ interval = 1 means that we apply checkpointing
126
+ on every layer (if activation), otherwise,
127
+ we apply it once every "interval" layers.
128
+ "offload": bool, default False,
129
+ if set to True, we offload activation to cpu and
130
+ reload it during backward, otherwise,
131
+ we recalculate activation in backward.
132
+ default "".
133
+ export: bool, optional
134
+ if set to True, it removes the padding from convolutional layers
135
+ and allows ONNX conversion for inference.
136
+ default False.
137
+ use_pt_scaled_dot_product_attention: bool, optional
138
+ if set to True, use pytorch's scaled dot product attention
139
+ implementation in training.
140
+ attn_group_sizes: int, optional
141
+ the number of groups to use for attention, default 1
142
+ (Multi-Head Attention),
143
+ 1 = typical Multi-Head Attention,
144
+ 1 < attn_group_sizes < attention_heads = Grouped-Query Attention
145
+ attn_group_sizes = attention_heads = Multi-Query Attention
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ d_model=512,
151
+ ext_pw_out_channel=0,
152
+ depthwise_seperable_out_channel=256,
153
+ depthwise_multiplier=1,
154
+ n_head=4,
155
+ d_ffn=2048,
156
+ ext_pw_kernel_size=1,
157
+ kernel_size=3,
158
+ dropout_rate=0.1,
159
+ causal=False,
160
+ batch_norm=False,
161
+ activation="relu",
162
+ chunk_se=0,
163
+ chunk_size=18,
164
+ conv_activation="relu",
165
+ conv_glu_type="sigmoid",
166
+ bias_in_glu=True,
167
+ linear_glu_in_convm=False,
168
+ attention_inner_dim=-1,
169
+ attention_glu_type="swish",
170
+ activation_checkpointing="",
171
+ export=False,
172
+ use_pt_scaled_dot_product_attention=False,
173
+ attn_group_sizes: int = 1,
174
+ ):
175
+ super().__init__()
176
+
177
+ self.feed_forward_in = FeedForward(
178
+ d_model=d_model,
179
+ d_inner=d_ffn,
180
+ dropout_rate=dropout_rate,
181
+ activation=activation,
182
+ bias_in_glu=bias_in_glu,
183
+ )
184
+
185
+ self.self_attn = MultiHeadedAttention(
186
+ n_head,
187
+ d_model,
188
+ dropout_rate,
189
+ attention_inner_dim,
190
+ attention_glu_type,
191
+ bias_in_glu,
192
+ use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
193
+ group_size=attn_group_sizes,
194
+ )
195
+ self.conv = ConvModule(
196
+ d_model,
197
+ ext_pw_out_channel,
198
+ depthwise_seperable_out_channel,
199
+ ext_pw_kernel_size,
200
+ kernel_size,
201
+ depthwise_multiplier,
202
+ dropout_rate,
203
+ causal,
204
+ batch_norm,
205
+ chunk_se,
206
+ chunk_size,
207
+ conv_activation,
208
+ conv_glu_type,
209
+ bias_in_glu,
210
+ linear_glu_in_convm,
211
+ export=export,
212
+ )
213
+
214
+ self.feed_forward_out = FeedForward(
215
+ d_model=d_model,
216
+ d_inner=d_ffn,
217
+ dropout_rate=dropout_rate,
218
+ activation=activation,
219
+ bias_in_glu=bias_in_glu,
220
+ )
221
+
222
+ self.layer_norm_att = nn.LayerNorm(d_model)
223
+ self.layer_norm = nn.LayerNorm(d_model)
224
+
225
+ def forward(
226
+ self,
227
+ x,
228
+ pos_k,
229
+ pos_v,
230
+ mask,
231
+ relative_attention_bias: Optional[Tensor] = None,
232
+ ):
233
+ """ConformerEncoder forward.
234
+
235
+ Args:
236
+ x: torch.Tensor
237
+ input feature of shape (batch, max_time_in, size)
238
+ pos_k: torch.Tensor
239
+ positional key embedding.
240
+ mask: torch.Tensor
241
+ mask for x (batch, max_time_in)
242
+ relative_attention_bias: Optional[torch.Tensor]
243
+ bias added to attention logits w.r.t. relative positions
244
+ (1, n_head, time1, time2)
245
+ """
246
+ x = x + 0.5 * self.feed_forward_in(x)
247
+ norm_x = self.layer_norm_att(x)
248
+
249
+ x = x + self.self_attn(
250
+ norm_x,
251
+ norm_x,
252
+ norm_x,
253
+ pos_k,
254
+ pos_v,
255
+ mask,
256
+ relative_attention_bias=relative_attention_bias,
257
+ )
258
+ x = x + self.conv(x)
259
+ x = x + 0.5 * self.feed_forward_out(x)
260
+
261
+ out = self.layer_norm(x)
262
+
263
+ return out, pos_k, pos_v, mask
264
+
265
+
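The forward pass above follows the macaron-style Conformer ordering: a half-step feed-forward, self-attention, convolution, a second half-step feed-forward, then a final LayerNorm. The sketch below restates that residual pattern with stock PyTorch modules standing in for the package's FeedForward, MultiHeadedAttention, and ConvModule classes; it is illustrative only, not the sglang implementation.

    import torch
    import torch.nn as nn

    class TinyConformerBlock(nn.Module):
        """Illustrative macaron-style block; stock modules replace the real ones."""

        def __init__(self, d_model=64, n_head=4, d_ffn=128, kernel_size=3):
            super().__init__()
            self.ffn_in = nn.Sequential(nn.Linear(d_model, d_ffn), nn.SiLU(), nn.Linear(d_ffn, d_model))
            self.norm_att = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
            self.conv = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2, groups=d_model)
            self.ffn_out = nn.Sequential(nn.Linear(d_model, d_ffn), nn.SiLU(), nn.Linear(d_ffn, d_model))
            self.norm_out = nn.LayerNorm(d_model)

        def forward(self, x):  # x: (batch, time, d_model)
            x = x + 0.5 * self.ffn_in(x)                             # half-step FFN residual
            h = self.norm_att(x)
            x = x + self.attn(h, h, h, need_weights=False)[0]        # self-attention residual
            x = x + self.conv(x.transpose(1, 2)).transpose(1, 2)     # depthwise conv residual
            x = x + 0.5 * self.ffn_out(x)                            # second half-step FFN residual
            return self.norm_out(x)

    print(TinyConformerBlock()(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])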
266
+ class TransformerEncoderBase(abc.ABC, nn.Module):
267
+ """The Base class for Transformer based encoders
268
+
269
+ Please set causal = True in streaming model
270
+ Args:
271
+ input_size: int
272
+ input feature dimension.
273
+ chunk_size: int, list(int)
274
+ Number of frames for each chunk
275
+ This variable can take 2 forms:
276
+ int: Used for inference, or single chunk size training
277
+ list(int) : Used only for variable chunk size training
278
+ Some examples for the 2 cases:
279
+ chunk_size = 12
280
+ chunk_size = [6, 8, 12, 24]
281
+ left_chunk: int, list(int)
282
+ Number of chunks used for masking in streaming mode.
283
+ This variable can take 2 forms:
284
+ int: Used for inference, or single chunk size training
285
+ list(int) : Used only for variable chunk size training. When
286
+ chunk_size is a list, left_chunk must be a list with same length.
287
+ Some examples for the 2 cases:
288
+ left_chunk = 6
289
+ left_chunk = [12, 9, 6, 3]
290
+ attention_dim: int, optional
291
+ attention dimension. default 256.
292
+ attention_heads: int, optional
293
+ the number of heads. default 4
294
+ input_layer: str, optional
295
+ input layer type before Conformer,
296
+ one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
297
+ default "conv2d"
298
+ cnn_out: int, optional
299
+ the number of CNN channels before Conformer.
300
+ default -1.
301
+ cnn_layer_norm: bool, optional
302
+ layer norm between Conformer and the first CNN.
303
+ default False.
304
+ time_reduction: int, optional
305
+ time reduction factor
306
+ default 4
307
+ dropout_rate: float, optional
308
+ dropout rate. default 0.1
309
+ padding_idx: int, optional
310
+ padding index for input_layer=embed
311
+ default -1
312
+ relative_attention_bias_args: dict, optional
313
+ use more efficient scalar bias-based relative multihead attention
314
+ (Q*K^T + B) implemented in cmb.basics.embedding.
315
+ [T5/ALiBi]RelativeAttentionLogitBias
316
+ usage: relative_attention_bias_args={"type": t5/alibi}
317
+ additional method-specific arguments can be provided (see
318
+ transformer_base.py)
319
+ positional_dropout_rate: float, optional
320
+ dropout rate after positional encoding. default 0.0
321
+ nemo_conv_settings: dict, optional
322
+ A dictionary of settings for NeMo Subsampling.
323
+ default None
324
+ conv2d_extra_padding: str, optional
325
+ Add extra padding in conv2d subsampling layers. Choices are
326
+ (feat, feat_time, none, True).
327
+ if True or feat_time, the extra padding is added into non full
328
+ supraframe utts in batch.
329
+ Default: none
330
+ attention_group_size: int, optional
331
+ the number of groups to use for attention, default 1
332
+ (Multi-Head Attention),
333
+ 1 = typical Multi-Head Attention,
334
+ 1 < attention_group_size < attention_heads = Grouped-Query
335
+ Attention
336
+ attention_group_size = attention_heads = Multi-Query Attention
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ input_size,
342
+ chunk_size,
343
+ left_chunk,
344
+ attention_dim=256,
345
+ attention_heads=4,
346
+ input_layer="nemo_conv",
347
+ cnn_out=-1,
348
+ cnn_layer_norm=False,
349
+ time_reduction=4,
350
+ dropout_rate=0.0,
351
+ padding_idx=-1,
352
+ relative_attention_bias_args=None,
353
+ positional_dropout_rate=0.0,
354
+ nemo_conv_settings=None,
355
+ conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
356
+ attention_group_size=1,
357
+ encoder_embedding_config=None,
358
+ ):
359
+ super().__init__()
360
+ self.input_size = input_size
361
+ self.input_layer = input_layer
362
+ self.chunk_size = chunk_size
363
+ self.left_chunk = left_chunk
364
+ self.attention_dim = attention_dim
365
+ self.num_heads = attention_heads
366
+ self.attention_group_size = attention_group_size
367
+ self.time_reduction = time_reduction
368
+ self.nemo_conv_settings = nemo_conv_settings
369
+ self.encoder_embedding_config = encoder_embedding_config
370
+
371
+ if self.input_layer == "nemo_conv":
372
+ default_nemo_conv_settings = {
373
+ "subsampling": "dw_striding",
374
+ "subsampling_factor": self.time_reduction,
375
+ "feat_in": input_size,
376
+ "feat_out": attention_dim,
377
+ "conv_channels": 256,
378
+ "subsampling_conv_chunking_factor": 1,
379
+ "activation": nn.ReLU(),
380
+ "is_causal": False,
381
+ }
382
+ # Override any of the defaults with the incoming, user settings
383
+ if nemo_conv_settings:
384
+ default_nemo_conv_settings.update(nemo_conv_settings)
385
+ for i in ["subsampling_factor", "feat_in", "feat_out"]:
386
+ assert (
387
+ i not in nemo_conv_settings
388
+ ), "{i} should be specified outside of the NeMo dictionary"
389
+
390
+ self.embed = NemoConvSubsampling(
391
+ **default_nemo_conv_settings,
392
+ )
393
+ else:
394
+ raise ValueError("unknown input_layer: " + input_layer)
395
+
396
+ self.pos_emb = AbsolutePositionalEncoding(
397
+ attention_dim, positional_dropout_rate
398
+ )
399
+
400
+ self.relative_attention_bias_type = (
401
+ relative_attention_bias_args.get("type")
402
+ if relative_attention_bias_args
403
+ else None
404
+ )
405
+ if self.relative_attention_bias_type == "t5":
406
+ assert (
407
+ self.num_heads % self.attention_group_size == 0
408
+ ), "attention_group_size must divide n_head"
409
+ self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
410
+ self.num_heads // self.attention_group_size,
411
+ max_distance=relative_attention_bias_args.get(
412
+ "t5_bias_max_distance", 1000
413
+ ),
414
+ symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
415
+ )
416
+ else:
417
+ raise NotImplementedError
418
+
419
+ self.encoder_embedding = MeanVarianceNormLayer(
420
+ self.encoder_embedding_config["input_size"]
421
+ )
422
+
423
+ def compute_lens_change(self, feature_lens):
424
+ """feature_lens: int
425
+ return updated feature lens.
426
+
427
+ This used to return a different lambda function for each case that
428
+ computed the right thing. That does not work within Torchscript.
429
+ If you really need this to be faster, create nn.Module()-s for all
430
+ the cases and return one of them. Torchscript does support that.
431
+ """
432
+ if self.input_layer == "nemo_conv":
433
+ # Handle the special causal case
434
+ subsampling_causal_cond = self.nemo_conv_settings.get(
435
+ "subsampling", "dw_striding"
436
+ ) in [
437
+ "dw_striding",
438
+ "striding",
439
+ "striding_conv1d",
440
+ ]
441
+ is_causal = self.nemo_conv_settings.get("is_causal", False)
442
+ if is_causal and subsampling_causal_cond:
443
+ lens_change = (
444
+ torch.ceil(feature_lens / self.time_reduction).long()
445
+ if isinstance(feature_lens, Tensor)
446
+ else math.ceil(feature_lens / self.time_reduction)
447
+ )
448
+ feature_lens_remainder = feature_lens % self.time_reduction
449
+ if isinstance(feature_lens, Tensor):
450
+ lens_change[feature_lens_remainder != 1] += 1
451
+ elif feature_lens_remainder != 1:
452
+ lens_change += 1
453
+ return lens_change
454
+ ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
455
+ return ceil_func(feature_lens / self.time_reduction)
456
+
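For a concrete sense of the length bookkeeping above, assume time_reduction=4 and a 100-frame input: the plain branch returns ceil(100 / 4) = 25, while the causal nemo_conv branch adds one extra frame whenever the remainder is not 1.

    import math

    time_reduction, feature_lens = 4, 100

    # non-causal: plain ceiling division
    print(math.ceil(feature_lens / time_reduction))   # 25

    # causal "nemo_conv" branch: +1 unless feature_lens % time_reduction == 1
    lens_change = math.ceil(feature_lens / time_reduction)
    if feature_lens % time_reduction != 1:
        lens_change += 1
    print(lens_change)                                # 26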
457
+ @abc.abstractmethod
458
+ def forward(self):
459
+ """Abstract forward method implementation."""
460
+
461
+ def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
462
+ """If chunk size is a list, we will randomly select a chunk size."""
463
+
464
+ if chunk_size is None:
465
+ chunk_size = self.chunk_size
466
+ if left_chunk is None:
467
+ left_chunk = self.left_chunk
468
+ if isinstance(chunk_size, list):
469
+ # Variable chunk size during training
470
+ chunk_size_index = int(
471
+ torch.randint(low=0, high=len(chunk_size), size=(1,))
472
+ )
473
+ chunk_size_train_eff = chunk_size[chunk_size_index]
474
+ if not isinstance(left_chunk, list):
475
+ raise ValueError(
476
+ "Since chunk_size is a list, left_chunk must be a list"
477
+ )
478
+ if len(left_chunk) != len(chunk_size):
479
+ raise ValueError(
480
+ "The length of left_chunk must be the same as length of "
481
+ "chunk_size."
482
+ )
483
+ left_chunk_train_eff = left_chunk[chunk_size_index]
484
+ else:
485
+ chunk_size_train_eff = chunk_size
486
+ left_chunk_train_eff = left_chunk
487
+
488
+ return chunk_size_train_eff, left_chunk_train_eff
489
+
490
+ def _get_embed_class(self, embed):
491
+ # pylint: disable=protected-access
492
+ is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
493
+ is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
494
+ embed_class = embed
495
+ if is_embed_using_act_chkpt:
496
+ embed_class = embed._checkpoint_wrapped_module
497
+ if is_embed_fsdp_wrapped:
498
+ embed_class = embed.module
499
+ return embed_class
500
+
501
+ def _forward_embeddings_core(self, input_tensor, masks):
502
+ embed_class = self._get_embed_class(self.embed)
503
+ assert isinstance(embed_class, NemoConvSubsampling)
504
+ input_tensor, masks = self.embed(input_tensor, masks)
505
+ return input_tensor, masks
506
+
507
+ def _position_embedding(self, input_tensor):
508
+ pos_k = None
509
+ pos_v = None
510
+ if self.relative_attention_bias_layer is None:
511
+ input_tensor = self.pos_emb(
512
+ input_tensor
513
+ ) # default to add abs sinusoid embedding
514
+ return pos_k, pos_v
515
+
516
+ def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
517
+ chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
518
+ chunk_size, left_chunk
519
+ )
520
+
521
+ # Create mask matrix for streaming
522
+ # S stores start index. if chunksize is 18, s is [0,18,36,....]
523
+ chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
524
+
525
+ enc_streaming_mask = (
526
+ adaptive_enc_mask(
527
+ seq_len, chunk_start_idx, left_window=left_chunk_train_eff
528
+ )
529
+ .unsqueeze(0)
530
+ .expand([batch_size, -1, -1])
531
+ )
532
+ return enc_streaming_mask
533
+
534
+ def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None):
535
+ """Forwarding the inputs through the top embedding layers
536
+
537
+ Args:
538
+ xs_pad: torch.Tensor
539
+ input tensor
540
+ masks: torch.Tensor
541
+ input mask
542
+ chunk_size_nc: (optional, default is None) chunk size for
543
+ non-causal layers
544
+ left_chunk_nc: (optional, default is None) # of left chunks for
545
+ non-causal layers
546
+ """
547
+ # pylint: disable=R0915
548
+ # get new lens.
549
+ seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
550
+ if seq_len <= 0:
551
+ raise ValueError(
552
+ f"""The sequence length after time reduction is invalid:
553
+ {seq_len}. Your input feature is too short. Consider
554
+ filtering out the very short sentence from data
555
+ loader""",
556
+ )
557
+
558
+ batch_size = xs_pad.shape[0]
559
+
560
+ enc_streaming_mask = self._streaming_mask(
561
+ seq_len, batch_size, self.chunk_size, self.left_chunk
562
+ )
563
+
564
+ if xs_pad.is_cuda:
565
+ enc_streaming_mask = enc_streaming_mask.cuda()
566
+ xs_pad = xs_pad.cuda()
567
+
568
+ input_tensor = xs_pad
569
+ input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
570
+
571
+ streaming_mask = enc_streaming_mask
572
+ if streaming_mask is not None and masks is not None:
573
+ hs_mask = masks & streaming_mask
574
+ elif masks is not None:
575
+ hs_mask = masks
576
+ else:
577
+ hs_mask = streaming_mask
578
+
579
+ if chunk_size_nc is not None:
580
+ enc_streaming_mask_nc = self._streaming_mask(
581
+ seq_len, batch_size, chunk_size_nc, left_chunk_nc
582
+ )
583
+ if xs_pad.is_cuda:
584
+ enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
585
+ if masks is not None:
586
+ hs_mask_nc = masks & enc_streaming_mask_nc
587
+ else:
588
+ hs_mask_nc = enc_streaming_mask_nc
589
+ else:
590
+ hs_mask_nc = None
591
+
592
+ pos_k, pos_v = self._position_embedding(input_tensor)
593
+
594
+ if chunk_size_nc is None:
595
+ return input_tensor, pos_k, pos_v, hs_mask, masks
596
+ return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
597
+
598
+ def get_offset(self):
599
+ """Returns offset used when retaining inputs for decoding.
600
+
601
+ This is essentially, how many additional frames have to be added to
602
+ the front-end CNN input to ensure it can produce a single output.
603
+ So if the "padding" parameter is 0, typically offset will be > 0.
604
+ """
605
+ return get_offset(self.input_layer, self.time_reduction)
606
+
607
+
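_streaming_mask above delegates to adaptive_enc_mask from phi4mm_utils, which is not part of this file. The toy function below reimplements the chunking behaviour described in the docstring (every frame may attend to its own chunk plus left_chunk chunks to its left) purely for illustration; the real helper may differ in details.

    import numpy as np
    import torch

    def toy_chunk_mask(seq_len, chunk_size, left_chunk):
        # Chunk id of every frame, e.g. chunk_size=3 -> [0, 0, 0, 1, 1, 1, ...]
        chunk_start_idx = torch.as_tensor(np.arange(0, seq_len, chunk_size))
        chunk_id = torch.searchsorted(chunk_start_idx, torch.arange(seq_len), right=True) - 1
        # Query frame i sees key frame j iff j's chunk lies in [chunk(i) - left_chunk, chunk(i)].
        q, k = chunk_id.unsqueeze(1), chunk_id.unsqueeze(0)
        return (k <= q) & (k >= q - left_chunk)

    print(toy_chunk_mask(seq_len=8, chunk_size=3, left_chunk=1).int())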
608
+ class ConformerEncoder(TransformerEncoderBase):
609
+ """ConformerEncoder module.
610
+ see original paper for more details:
611
+ https://arxiv.org/abs/2005.08100
612
+
613
+ Please set causal = True in streaming model
614
+ Args:
615
+ input_size: int
616
+ input feature dimension.
617
+ chunk_size: int, list(int)
618
+ Number of frames for each chunk
619
+ This variable can take 2 forms:
620
+ int: Used for inference, or single chunk size training
621
+ list(int) : Used only for variable chunk size training
622
+ Some examples for the 2 cases:
623
+ chunk_size = 12
624
+ chunk_size = [6, 8, 12, 24]
625
+ left_chunk: int, list(int)
626
+ Number of chunks used for masking in streaming mode.
627
+ This variable can take 2 forms:
628
+ int: Used for inference, or single chunk size training
629
+ list(int) : Used only for variable chunk size training. When
630
+ chunk_size is a list, left_chunk must be a list with same length.
631
+ Some examples for the 2 cases:
632
+ left_chunk = 6
633
+ left_chunk = [12, 9, 6, 3]
634
+ left_chunk: int
635
+ number of chunks used for masking in streaming mode.
636
+ num_lang: int
637
+ This parameter is used to store the number of languages in the
638
+ lang_dict, only used for multiseed/multilingual models.
639
+ default None.
640
+ attention_dim: int, optional
641
+ attention dimension. default 256.
642
+ attention_heads: int, optional
643
+ the number of heads. default 4
644
+ linear_units:
645
+ the number of units of position-wise feed forward.
646
+ default 2048
647
+ num_blocks:
648
+ number of Transformer layers. default 6
649
+ dropout_rate: float, optional
650
+ dropout rate. default 0.1
651
+ input_layer: str, optional
652
+ input layer type before Conformer,
653
+ one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
654
+ default "conv2d"
655
+ causal: bool, optional
656
+ if set to True, convolutions have no access
657
+ to future frames. default False.
658
+ batch_norm: bool, optional
659
+ if set to True, apply batchnorm before activation
660
+ in ConvModule layer of the conformer.
661
+ default False
662
+ cnn_out: int, optional
663
+ the number of CNN channels before Conformer.
664
+ default -1.
665
+ cnn_layer_norm: bool, optional
666
+ layer norm between Conformer and the first CNN.
667
+ default False.
668
+ ext_pw_out_channel: int, optional
669
+ the number of channel for CNN
670
+ before depthwise_seperable_CNN.
671
+ If 0 then use linear. default 0.
672
+ ext_pw_kernel_size: int, optional
673
+ kernel size of N before depthwise_seperable_CNN.
674
+ only work for ext_pw_out_channel > 0.
675
+ default 1
676
+ depthwise_seperable_out_channel: int, optional
677
+ the number of channel for
678
+ depthwise_seperable_CNN.
679
+ default 256.
680
+ depthwise_multiplier: int, optional
681
+ the number of multiplier for
682
+ depthwise_seperable_CNN.
683
+ default 1.
684
+ chunk_se: int, optional
685
+ 0 for offline SE.
686
+ 1 for streaming SE, where mean is computed
687
+ by accumulated history until current chunk_se.
688
+ 2 for streaming SE, where mean is computed
689
+ by only the current chunk.
690
+ default 0.
691
+ kernel_size: int, optional
692
+ the number of kernels for depthwise_seperable_CNN.
693
+ default 3.
694
+ activation: str, optional
695
+ FeedForward block activation.
696
+ one of ["relu", "swish", "sigmoid"]
697
+ default "relu".
698
+ conv_activation: str, optional
699
+ activation function used in ConvModule part
700
+ of the conformer, default "relu".
701
+ conv_glu_type: str, optional
702
+ activation used for the GLU in depthwise_seperable_CNN,
703
+ default "sigmoid"
704
+ bias_in_glu: bool, optional
705
+ if set to True, use additive bias in the weight module
706
+ before GLU. default True
707
+ linear_glu_in_convm: bool, optional
708
+ if set to True, use GLULinear module,
709
+ otherwise, use the GLUPointWiseConv module.
710
+ default to False.
711
+ attention_glu_type: str
712
+ only work for glu_in_attention !=0
713
+ default "swish".
714
+ export: bool, optional
715
+ if set to True, it removes the padding from convolutional layers
716
+ and allows ONNX conversion for inference.
717
+ default False.
718
+ activation_checkpointing: str, optional
719
+ a dictionary of {"module","interval","offload"}, where
720
+ "module": str
721
+ accept ["transformer", "attention"] to select
722
+ which module should do activation checkpointing.
723
+ "interval": int, default 1,
724
+ interval of applying activation checkpointing,
725
+ interval = 1 means that we apply checkpointing
726
+ on every layer (if activation), otherwise,
727
+ we apply it once every "interval" layers.
728
+ "offload": bool, default False,
729
+ if set to True, we offload activation to cpu and
730
+ reload it during backward, otherwise,
731
+ we recalculate activation in backward.
732
+ default "".
733
+ extra_layer_output_idx: int
734
+ the layer index to be exposed.
735
+ relative_attention_bias_args: dict, optional
736
+ use more efficient scalar bias-based relative multihead attention
737
+ (Q*K^T + B) implemented in cmb.basics.embedding.
738
+ [T5/ALiBi]RelativeAttentionLogitBias
739
+ usage: relative_attention_bias_args={"type": t5/alibi}
740
+ additional method-specific arguments can be provided (see
741
+ transformer_base.py)
742
+ time_reduction: int, optional
743
+ time reduction factor
744
+ default 4
745
+ use_pt_scaled_dot_product_attention: whether to use pytorch scaled
746
+ dot product attention in training.
747
+ Default: False
748
+ nemo_conv_settings: dict, optional
749
+ A dictionary of settings for NeMo Subsampling.
750
+ default: None
751
+ usage: nemo_conv_settings=
752
+ {
753
+ "subsampling":
754
+ dw_striding/striding/dw_striding_conv1d/striding_conv1d,
755
+ "conv_channels": int,
756
+ "subsampling_conv_chunking_factor": int,
757
+ "is_causal": True/False
758
+ }
759
+ conv2d_extra_padding: str, optional
760
+ Add extra padding in conv2d subsampling layers. Choices are
761
+ (feat, feat_time, none, True)
762
+ Default: none
763
+ replication_pad_for_subsample_embedding: For batched-streaming
764
+ decoding, use "replication" padding for the cache at start of
765
+ utterance.
766
+ Default: False
767
+ attention_group_size: int, optional
768
+ the number of groups to use for attention, default 1
769
+ (Multi-Head Attention),
770
+ 1 = typical Multi-Head Attention,
771
+ 1 < attention_group_size < attention_heads = Grouped-Query
772
+ Attention
773
+ attention_group_size = attention_heads = Multi-Query Attention
774
+ """
775
+
776
+ extra_multi_layer_output_idxs: list[int]
777
+
778
+ def __init__( # pylint: disable-all
779
+ self,
780
+ input_size,
781
+ chunk_size,
782
+ left_chunk,
783
+ num_lang=None,
784
+ attention_dim=256,
785
+ attention_heads=4,
786
+ linear_units=2048,
787
+ num_blocks=6,
788
+ dropout_rate=0.1,
789
+ input_layer="nemo_conv",
790
+ causal=True,
791
+ batch_norm=False,
792
+ cnn_out=-1,
793
+ cnn_layer_norm=False,
794
+ ext_pw_out_channel=0,
795
+ ext_pw_kernel_size=1,
796
+ depthwise_seperable_out_channel=256,
797
+ depthwise_multiplier=1,
798
+ chunk_se=0,
799
+ kernel_size=3,
800
+ activation="relu",
801
+ conv_activation="relu",
802
+ conv_glu_type="sigmoid",
803
+ bias_in_glu=True,
804
+ linear_glu_in_convm=False,
805
+ attention_glu_type="swish",
806
+ export=False,
807
+ extra_layer_output_idx=-1,
808
+ extra_multi_layer_output_idxs=[], # noqa
809
+ activation_checkpointing="",
810
+ relative_attention_bias_args=None,
811
+ time_reduction=4,
812
+ use_pt_scaled_dot_product_attention=False,
813
+ nemo_conv_settings=None,
814
+ conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
815
+ replication_pad_for_subsample_embedding=False,
816
+ attention_group_size=1,
817
+ encoder_embedding_config=None,
818
+ ):
819
+ super().__init__(
820
+ input_size,
821
+ chunk_size,
822
+ left_chunk,
823
+ attention_dim,
824
+ attention_heads,
825
+ input_layer,
826
+ cnn_out,
827
+ cnn_layer_norm,
828
+ time_reduction,
829
+ dropout_rate=dropout_rate,
830
+ relative_attention_bias_args=relative_attention_bias_args,
831
+ positional_dropout_rate=0.0,
832
+ nemo_conv_settings=nemo_conv_settings,
833
+ conv2d_extra_padding=conv2d_extra_padding,
834
+ attention_group_size=attention_group_size,
835
+ encoder_embedding_config=encoder_embedding_config,
836
+ )
837
+ self.num_blocks = num_blocks
838
+ self.num_lang = num_lang
839
+ self.kernel_size = kernel_size
840
+ self.replication_pad_for_subsample_embedding: bool = (
841
+ replication_pad_for_subsample_embedding
842
+ )
843
+ assert (
844
+ self.num_heads % attention_group_size == 0
845
+ ), "attention_group_size must divide n_head"
846
+ self.num_heads_k = self.num_heads // attention_group_size
847
+
848
+ self.encoders = MultiSequential(
849
+ *[
850
+ ConformerEncoderLayer(
851
+ d_model=attention_dim,
852
+ ext_pw_out_channel=ext_pw_out_channel,
853
+ depthwise_seperable_out_channel=depthwise_seperable_out_channel,
854
+ depthwise_multiplier=depthwise_multiplier,
855
+ n_head=attention_heads,
856
+ d_ffn=linear_units,
857
+ ext_pw_kernel_size=ext_pw_kernel_size,
858
+ kernel_size=kernel_size,
859
+ dropout_rate=dropout_rate,
860
+ causal=causal,
861
+ batch_norm=batch_norm,
862
+ activation=activation,
863
+ chunk_se=chunk_se,
864
+ chunk_size=chunk_size,
865
+ conv_activation=conv_activation,
866
+ conv_glu_type=conv_glu_type,
867
+ bias_in_glu=bias_in_glu,
868
+ linear_glu_in_convm=linear_glu_in_convm,
869
+ attention_glu_type=attention_glu_type,
870
+ activation_checkpointing=activation_checkpointing,
871
+ export=export,
872
+ use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
873
+ attn_group_sizes=attention_group_size,
874
+ )
875
+ for _ in range(num_blocks)
876
+ ]
877
+ )
878
+ self.extra_layer_output_idx = extra_layer_output_idx
879
+ self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
880
+ # Make a zeros scalar we can use in get_initial_state to determine
881
+ # the device and the needed dtype:
882
+ self.register_buffer("dev_type", torch.zeros(()), persistent=False)
883
+
884
+ def init_relative_attention_bias(self, input_tensor):
885
+ if self.relative_attention_bias_layer:
886
+ return self.relative_attention_bias_layer(input_tensor)
887
+
888
+ def calculate_hs_mask(self, xs_pad, device, mask):
889
+ max_audio_length = xs_pad.shape[1]
890
+ batch_size = xs_pad.shape[0]
891
+ enc_streaming_mask = self._streaming_mask(
892
+ max_audio_length, batch_size, self.chunk_size, self.left_chunk
893
+ )
894
+ enc_streaming_mask = enc_streaming_mask.to(device)
895
+ if mask is None:
896
+ return enc_streaming_mask
897
+
898
+ feature_lens = mask.sum(1)
899
+ padding_length = feature_lens
900
+ pad_mask = torch.arange(0, max_audio_length, device=device).expand(
901
+ padding_length.size(0), -1
902
+ ) < padding_length.unsqueeze(1)
903
+ pad_mask = pad_mask.unsqueeze(1)
904
+ pad_mask = pad_mask & enc_streaming_mask
905
+ return pad_mask
906
+
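The pad mask built in calculate_hs_mask is the usual arange-versus-lengths construction before it is intersected with the streaming mask; a minimal standalone version:

    import torch

    lengths = torch.tensor([5, 3])      # valid frames per utterance
    max_len = 6
    pad_mask = torch.arange(max_len).expand(len(lengths), -1) < lengths.unsqueeze(1)
    print(pad_mask.int())
    # tensor([[1, 1, 1, 1, 1, 0],
    #         [1, 1, 1, 0, 0, 0]])
    # calculate_hs_mask then unsqueezes a broadcast dim and ANDs this
    # with the chunked streaming mask.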
907
+ @torch.jit.ignore
908
+ def forward(self, xs_pad, masks):
909
+ """Conformer Forward function
910
+
911
+ Args:
912
+ xs_pad: torch.Tensor
913
+ input tensor
914
+ masks: torch.Tensor
915
+ post-embedding input lengths
916
+ """
917
+ xs_pad = self.encoder_embedding(xs_pad)
918
+ input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
919
+ xs_pad, masks
920
+ )
921
+
922
+ unfolded = False
923
+ ori_bz, seq_len, D = input_tensor.shape
924
+ max_seq_len = 500 # maximum position for absolute positional encoding
925
+ if seq_len > max_seq_len:
926
+ # audio sequence is longer than max_seq_len, unfold it into chunks
927
+ # of max_seq_len
928
+ unfolded = True
929
+ # the unfold op will drop residual frames, pad it to the multiple
930
+ # of max_seq_len
931
+ if seq_len % max_seq_len > 0:
932
+ chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
933
+ else:
934
+ chunk_pad_size = 0
935
+ if chunk_pad_size > 0:
936
+ input_tensor_pad = F.pad(
937
+ input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
938
+ )
939
+ input_tensor = input_tensor_pad.to(input_tensor.device)
940
+ input_tensor = unfold_tensor(input_tensor, max_seq_len)
941
+ if masks is not None:
942
+ # revise hs_mask here because the previous calculated hs_mask
943
+ # did not consider extra pad
944
+ subsampled_pad_mask = masks.squeeze(
945
+ 1
946
+ ) # [bz, subsampled_unmask_seq_len]
947
+ extra_padded_subsamlped_pad_mask = F.pad(
948
+ subsampled_pad_mask, (0, chunk_pad_size), "constant", False
949
+ ) # extra padding to the pad mask
950
+ extra_padded_subsamlped_pad_mask = (
951
+ extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
952
+ )
953
+ masks_unfold = unfold_tensor(
954
+ extra_padded_subsamlped_pad_mask, max_seq_len
955
+ ) # unfold the pad mask like we did to the input tensor
956
+ masks_unfold = masks_unfold.squeeze(
957
+ -1
958
+ ).bool() # unfold op does not support bool tensor
959
+ else:
960
+ masks_unfold = None
961
+ hs_mask = self.calculate_hs_mask(
962
+ input_tensor, input_tensor.device, masks_unfold
963
+ ) # calculate hs_mask based on the unfolded pad mask
964
+
965
+ # layer_emb = None
966
+
967
+ relative_attention_bias = self.init_relative_attention_bias(input_tensor)
968
+
969
+ _simplified_path = (
970
+ self.extra_layer_output_idx == -1 and relative_attention_bias is None
971
+ )
972
+
973
+ if _simplified_path:
974
+ input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
975
+ else:
976
+ for i, layer in enumerate(self.encoders):
977
+ input_tensor, _, _, _ = layer(
978
+ input_tensor,
979
+ pos_k,
980
+ pos_v,
981
+ hs_mask,
982
+ relative_attention_bias=relative_attention_bias,
983
+ )
984
+
985
+ # if i == self.extra_layer_output_idx:
986
+ # layer_emb = input_tensor
987
+
988
+ if unfolded:
989
+ embed_dim = input_tensor.shape[-1]
990
+ input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
991
+ # if we ever padded before unfolding, we need to remove the padding
992
+ if chunk_pad_size > 0:
993
+ input_tensor = input_tensor[:, :-chunk_pad_size, :]
994
+
995
+ return input_tensor, masks # , layer_emb
996
+
997
+
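In forward above, sequences longer than max_seq_len=500 are padded up to a multiple of 500, unfolded into fixed-size chunks for the encoder, then reshaped back and trimmed. The sketch below walks through that bookkeeping with a plain reshape standing in for unfold_tensor (assumed here to split (B, T, D) into (B * T / max_len, max_len, D)); the numbers are for a hypothetical 1234-frame input.

    import torch
    import torch.nn.functional as F

    max_seq_len = 500
    x = torch.randn(2, 1234, 64)                              # (batch, frames, dim)

    rem = x.shape[1] % max_seq_len
    chunk_pad_size = max_seq_len - rem if rem else 0          # 266
    x_pad = F.pad(x, (0, 0, 0, chunk_pad_size))               # (2, 1500, 64)

    # stand-in for unfold_tensor: one row per 500-frame chunk
    chunks = x_pad.reshape(-1, max_seq_len, x.shape[-1])      # (6, 500, 64)

    # ... the encoder layers run on `chunks` ...

    y = chunks.reshape(x.shape[0], -1, x.shape[-1])           # (2, 1500, 64)
    if chunk_pad_size > 0:
        y = y[:, :-chunk_pad_size, :]                         # back to (2, 1234, 64)
    print(y.shape)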
998
+ class WindowQformer(nn.Module):
999
+ """Window-level Qformer"""
1000
+
1001
+ def __init__(
1002
+ self,
1003
+ window_size: int = 8,
1004
+ num_queries: int = 1,
1005
+ num_blocks: int = 2,
1006
+ attention_dim: int = 512,
1007
+ attention_heads: int = 8,
1008
+ linear_units: int = 2048,
1009
+ dropout_rate: float = 0.0,
1010
+ normalize_before: bool = True,
1011
+ ):
1012
+ super().__init__()
1013
+
1014
+ self.decoders = nn.ModuleList(
1015
+ [
1016
+ nn.TransformerDecoderLayer(
1017
+ d_model=attention_dim,
1018
+ nhead=attention_heads,
1019
+ dim_feedforward=linear_units,
1020
+ dropout=dropout_rate,
1021
+ activation="relu",
1022
+ batch_first=True,
1023
+ norm_first=normalize_before, # TODO need to verify
1024
+ )
1025
+ for _ in range(num_blocks)
1026
+ ]
1027
+ )
1028
+
1029
+ self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
1030
+ self.after_norm = (
1031
+ nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
1032
+ )
1033
+ self.window_size = window_size
1034
+
1035
+ def forward(self, audio_embed, mask, embed_len=None):
1036
+ """forward decoder"""
1037
+ # audio_embed: N x T x D => N x D x T
1038
+
1039
+ audio_embed = audio_embed.transpose(1, 2)
1040
+ # audio_embed: N x D x 1 x T => N x DK x T'
1041
+ padding = audio_embed.shape[-1] % self.window_size
1042
+ if padding > 0:
1043
+ audio_embed = F.pad(
1044
+ audio_embed, (0, self.window_size - padding), "constant", 0
1045
+ )
1046
+
1047
+ embed_chunk = F.unfold(
1048
+ audio_embed[..., None, :],
1049
+ kernel_size=(1, self.window_size),
1050
+ stride=(1, self.window_size),
1051
+ )
1052
+ bsz, _, slen = embed_chunk.shape
1053
+ # N x D x K x T'
1054
+ embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen)
1055
+ # N x T' x K x D
1056
+ embed_chunk = embed_chunk.transpose(1, 3).contiguous()
1057
+ # NT' x K x D
1058
+ embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1)
1059
+ # NT' x 1 x D
1060
+ q = self.queries.expand(bsz * slen, -1, -1)
1061
+ for layer in self.decoders:
1062
+ q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)
1063
+
1064
+ if self.after_norm is not None:
1065
+ q = self.after_norm(q)
1066
+
1067
+ if embed_len is not None:
1068
+ embed_len = embed_len // self.window_size
1069
+ # N x T' x D
1070
+ out = q.view(bsz, slen, -1)
1071
+
1072
+ return out, embed_len
1073
+
1074
+
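WindowQformer cuts the frame sequence into non-overlapping windows with F.unfold and then lets its learned queries cross-attend to each window. The shape-only sketch below mirrors the windowing step with the default window_size=8 and a hypothetical 50-frame, 512-dim input.

    import torch
    import torch.nn.functional as F

    N, D, T, K = 2, 512, 50, 8                  # batch, dim, frames, window_size
    x = torch.randn(N, D, T)                    # already transposed to (N, D, T)

    pad = (K - T % K) % K                       # 6 frames of right padding
    x = F.pad(x, (0, pad))                      # (2, 512, 56)

    # view the sequence as a 1 x T "image" and cut 1 x K patches
    chunks = F.unfold(x[..., None, :], kernel_size=(1, K), stride=(1, K))   # (2, 512*8, 7)
    bsz, _, slen = chunks.shape
    chunks = chunks.view(bsz, D, K, slen).transpose(1, 3).contiguous()      # (2, 7, 8, 512)
    chunks = chunks.view(bsz * slen, K, D)      # (14, 8, 512): one window per row
    print(chunks.shape)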
1075
+ class AudioEmbedding(nn.Module):
1076
+ """Image embedding."""
1077
+
1078
+ def __init__(self, config: PretrainedConfig, **kwargs) -> None:
1079
+ super().__init__()
1080
+ self.config = config
1081
+ # n_embed or hidden_size for text LM
1082
+ hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
1083
+
1084
+ # self.wte = nn.Embedding(config.vocab_size, hidden_size)
1085
+
1086
+ audio_dim_out = (
1087
+ None # Set this variable according to the actual audio processor
1088
+ )
1089
+ self.layer_idx = -2
1090
+
1091
+ if (
1092
+ isinstance(config.audio_processor, dict)
1093
+ and config.audio_processor.get("name", None) == "cascades"
1094
+ ):
1095
+ encoder_config = config.audio_processor.get("config", None)
1096
+ assert encoder_config is not None
1097
+ self.encoder = ConformerEncoder(**encoder_config)
1098
+
1099
+ audio_dim_out = encoder_config["attention_dim"]
1100
+ n_mels = encoder_config["input_size"]
1101
+ else:
1102
+ raise NotImplementedError("")
1103
+
1104
+ assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
1105
+ self.audio_dim_out = audio_dim_out
1106
+ self.audio_dim_in = n_mels
1107
+
1108
+ self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)
1109
+
1110
+ self.downsample_rate = kwargs.get("downsample_rate", 1)
1111
+
1112
+ if kwargs.get("use_qformer", False):
1113
+ qformer_config = kwargs.get("qformer_config", {})
1114
+ qformer_config["attention_dim"] = audio_dim_out
1115
+ self.qformer = WindowQformer(**qformer_config)
1116
+ else:
1117
+ self.qformer = None
1118
+
1119
+ if kwargs.get("use_conv_downsample", False):
1120
+ assert (
1121
+ self.qformer is None
1122
+ ), "don't support use qformer and conv downsample together"
1123
+ nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
1124
+ default_nemo_conv_settings = {
1125
+ "subsampling": "dw_striding",
1126
+ "subsampling_factor": self.downsample_rate,
1127
+ "feat_in": audio_dim_out,
1128
+ "feat_out": audio_dim_out,
1129
+ "conv_channels": 256,
1130
+ "subsampling_conv_chunking_factor": 1,
1131
+ "activation": nn.ReLU(),
1132
+ "is_causal": False,
1133
+ }
1134
+ # Override any of the defaults with the incoming, user settings
1135
+ if nemo_conv_settings:
1136
+ default_nemo_conv_settings.update(nemo_conv_settings)
1137
+ for i in ["subsampling_factor", "feat_in", "feat_out"]:
1138
+ assert (
1139
+ i not in nemo_conv_settings
1140
+ ), "{i} should be specified outside of the NeMo dictionary"
1141
+
1142
+ self.conv_ds = NemoConvSubsampling(
1143
+ **default_nemo_conv_settings,
1144
+ )
1145
+ else:
1146
+ self.conv_ds = None
1147
+
1148
+ projection_cls = kwargs.get("projection_cls", "linear")
1149
+ if projection_cls == "linear":
1150
+ self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
1151
+ elif projection_cls == "mlp":
1152
+ # follow llava-v1.5's implementation
1153
+ # (do not use image_projection and image_proj_norm)
1154
+ dim_projection = hidden_size
1155
+ depth = 2
1156
+ self.linear_downsample_rate = (
1157
+ 1 if (self.qformer or self.conv_ds) else self.downsample_rate
1158
+ )
1159
+ layers = [
1160
+ nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
1161
+ ]
1162
+ for _ in range(1, depth):
1163
+ layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
1164
+ self.audio_projection = nn.Sequential(*layers)
1165
+ # NOTE vision-speech tasks use a separate projection layer
1166
+ layers = [
1167
+ nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
1168
+ ]
1169
+ for _ in range(1, depth):
1170
+ layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
1171
+ self.audio_projection_for_vision = nn.Sequential(*layers)
1172
+ else:
1173
+ raise NotImplementedError(
1174
+ f"projection_cls = {projection_cls}, not implemented"
1175
+ )
1176
+
1177
+ # TODO: audio sequence compression - Qformer
1178
+ self.vocab_size = config.vocab_size
1179
+ self.input_embeds = None
1180
+ self.audio_embed_sizes = None
1181
+
1182
+ def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
1183
+ self.input_embeds = input_embeds
1184
+
1185
+ def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None:
1186
+ self.audio_embed_sizes = audio_embed_sizes
1187
+
1188
+ def get_audio_features(
1189
+ self,
1190
+ input_embeds: torch.FloatTensor,
1191
+ audio_attention_mask: torch.Tensor = None,
1192
+ audio_projection_mode: str = "speech",
1193
+ ) -> torch.FloatTensor:
1194
+ """
1195
+ arguments:
1196
+ input_embeds: audio features (B, T, D) B: num audios in a sequence
1197
+ """
1198
+ if self.freeze_audio_processor:
1199
+ with torch.no_grad():
1200
+ audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
1201
+ else:
1202
+ audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
1203
+
1204
+ if self.qformer is not None:
1205
+ audio_features, _ = self.qformer(audio_features, mask=None)
1206
+
1207
+ if self.conv_ds is not None:
1208
+ if masks is not None:
1209
+ masks = masks.squeeze(1)
1210
+
1211
+ audio_features, masks = self.conv_ds(audio_features, mask=masks)
1212
+
1213
+ if self.linear_downsample_rate != 1:
1214
+ bs, seq_len, feat_dim = audio_features.size()
1215
+ padding = seq_len % self.linear_downsample_rate
1216
+ if padding > 0:
1217
+ audio_features = F.pad(
1218
+ audio_features,
1219
+ (0, 0, 0, self.linear_downsample_rate - padding),
1220
+ "constant",
1221
+ 0,
1222
+ )
1223
+
1224
+ seq_len = audio_features.size(1)
1225
+ audio_features = audio_features.view(
1226
+ bs,
1227
+ seq_len // self.linear_downsample_rate,
1228
+ feat_dim * self.linear_downsample_rate,
1229
+ )
1230
+
1231
+ if audio_projection_mode == "speech":
1232
+ audio_set_tensor = self.audio_projection(audio_features)
1233
+ elif audio_projection_mode == "vision":
1234
+ audio_set_tensor = self.audio_projection_for_vision(audio_features)
1235
+ else:
1236
+ raise ValueError(
1237
+ f"audio_projection_mode = {audio_projection_mode} not " "implemented"
1238
+ )
1239
+
1240
+ return audio_set_tensor
1241
+
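When neither the Qformer nor the conv subsampler is active, the "mlp" projection path above downsamples purely by reshaping: every linear_downsample_rate consecutive frames are concatenated along the feature axis before the projection. A small sketch, with an assumed rate of 4 and 1024-dim encoder features:

    import torch
    import torch.nn.functional as F

    rate = 4
    feats = torch.randn(2, 21, 1024)                 # (batch, frames, audio_dim_out)

    pad = (rate - feats.shape[1] % rate) % rate      # 3 frames of zero padding
    feats = F.pad(feats, (0, 0, 0, pad))             # (2, 24, 1024)

    # stack every `rate` consecutive frames into one feature vector
    feats = feats.view(2, feats.shape[1] // rate, 1024 * rate)   # (2, 6, 4096)
    # audio_projection (Linear or MLP) then maps 1024 * rate -> hidden_size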
1242
+ def forward(
1243
+ self,
1244
+ audio_features: torch.FloatTensor,
1245
+ audio_attention_mask: torch.Tensor = None,
1246
+ audio_projection_mode: str = "speech",
1247
+ ) -> torch.FloatTensor:
1248
+ """
1249
+ arguments:
1250
+ audio_features: audio features (num_audio_tokens, T, D)
1251
+
1252
+ returns:
1253
+ audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
1254
+ """
1255
+ audio_embeds = self.get_audio_features(
1256
+ audio_features,
1257
+ audio_attention_mask=audio_attention_mask,
1258
+ audio_projection_mode=audio_projection_mode,
1259
+ )
1260
+ return audio_embeds