sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -100,6 +101,7 @@ class Qwen2Attention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
@@ -123,7 +125,10 @@ class Qwen2Attention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
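
Context for the hunk above: the per-head width used to be derived as hidden_size // num_heads, which mis-sizes Q/K/V for checkpoints whose config carries an explicit head_dim. A minimal standalone sketch of the new fallback (the sizes below are hypothetical):

```python
# Sketch of the head_dim fallback added above; sizes are hypothetical.
from typing import Optional

def resolve_head_dim(hidden_size: int, num_heads: int, head_dim: Optional[int] = None) -> int:
    # An explicit head_dim from the model config wins; otherwise derive it.
    return head_dim if head_dim is not None else hidden_size // num_heads

assert resolve_head_dim(4096, 32) == 128              # derived
assert resolve_head_dim(4096, 32, head_dim=64) == 64  # config override
```
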
@@ -185,16 +190,19 @@ class Qwen2DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
@@ -246,6 +254,7 @@ class Qwen2Model(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -258,6 +267,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
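
The embed_tokens change switches tensor parallelism off whenever DP attention is enabled, so each data-parallel rank keeps the full vocabulary embedding rather than a vocab-parallel shard. A self-contained sketch of the flag logic (the plain dict stands in for global_server_args_dict):

```python
# Stand-in for sglang's global_server_args_dict.
server_args = {"enable_dp_attention": True}

# With DP attention, every DP rank needs the whole embedding table locally,
# so vocab-parallel sharding of embed_tokens is turned off.
enable_tp = not server_args["enable_dp_attention"]
assert enable_tp is False
```
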
@@ -272,6 +282,7 @@ class Qwen2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
@@ -282,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -310,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -326,8 +345,16 @@ class Qwen2Model(nn.Module):
                 }
             )
         else:
-            hidden_states, _ = self.norm(hidden_states, residual)
-            return hidden_states
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+            if len(aux_hidden_states) == 0:
+                return hidden_states
+
+            return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
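
The layers_to_capture hooks above serve EAGLE3 speculative decoding: before each selected layer runs, the pre-layer hidden state (with any pending residual folded back in) is snapshotted, and the model returns those snapshots alongside the final hidden states. A toy, self-contained sketch of that capture loop (the layer behavior and capture indices are made up):

```python
import torch

# Toy layers: each returns (new_hidden, new_residual); indices are made up.
layers = [lambda h, r: (h * 2.0, h if r is None else h + r)] * 4
layers_to_capture = [1, 3]  # hypothetical EAGLE3 capture points

hidden, residual = torch.ones(2, 8), None
aux_hidden_states = []
for i, layer in enumerate(layers):
    if i in layers_to_capture:
        # Fold the pending residual back in before snapshotting.
        aux_hidden_states.append(hidden + residual if residual is not None else hidden)
    hidden, residual = layer(hidden, residual)

# Mirrors the new return contract: a bare tensor when nothing was captured,
# otherwise a (hidden_states, aux_hidden_states) pair.
out = hidden if not aux_hidden_states else (hidden, aux_hidden_states)
```
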
@@ -398,6 +425,7 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
+
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
sglang/srt/models/qwen2_5_vl.py
@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
sglang/srt/models/qwen2_audio.py (new file)
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+import logging
+import math
+from functools import lru_cache, partial
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
+from transformers.models.qwen2_audio.modeling_qwen2_audio import (
+    Qwen2AudioEncoder,
+    Qwen2AudioMultiModalProjector,
+)
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2AudioForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2AudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if getattr(self.config, "audio_config", None) is None:
+            self.config.audio_config = Qwen2AudioEncoderConfig(
+                self.config._name_or_path
+            )
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.audio_tower = Qwen2AudioEncoder(
+            config.audio_config,
+        )
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs for audio
+        audio_token_id: int = getattr(
+            mm_inputs, "audio_token_id", mm_inputs.im_token_id
+        )
+
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract audio features from input items
+        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+            self.audio_tower.dtype
+        )
+
+        audio_embeds = self.audio_tower(input_features).last_hidden_state
+        audio_embeds = self.multi_modal_projector(audio_embeds)
+
+        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
+        new_embeds = []
+        for i, d in zip(audio_feature_lens, audio_embeds):
+            new_embeds.append(d[: i.item()])
+
+        return torch.cat(new_embeds, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name or "audio_tower" in name:
+                    continue
+                name_tmp = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
+                    continue
+                param = params_dict[name_tmp]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2AudioForConditionalGeneration
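
In get_audio_feature above, the encoder pads every clip in the batch to a common frame count, so each clip's embedding is trimmed back to its true length before being concatenated into one flat sequence. A toy sketch of that trim-and-concat step:

```python
import torch

# Batch of 2 clips, padded to 6 frames, embedding dim 4 (toy shapes).
audio_embeds = torch.randn(2, 6, 4)
audio_feature_lens = torch.tensor([5, 3])  # true frames per clip

# Keep only the valid frames of each clip, then flatten to one sequence.
new_embeds = [emb[: n.item()] for n, emb in zip(audio_feature_lens, audio_embeds)]
merged = torch.cat(new_embeds, dim=0)
assert merged.shape == (8, 4)  # 5 + 3 valid frames survive
```
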
sglang/srt/models/qwen2_moe.py
@@ -31,6 +31,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -64,11 +69,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -143,6 +143,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
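
The experts constructor gains its FlashInfer arguments through conditional dict unpacking: when the flag is off, the call collapses to the old signature, so expert implementations that don't accept the new kwargs never see them. A minimal sketch of the pattern, with stand-ins for the server-args dict and the constructor:

```python
# Stand-ins for global_server_args_dict and the FusedMoE constructor.
server_args = {"enable_flashinfer_moe": True, "enable_ep_moe": False}

def make_experts(num_experts: int, **extra) -> dict:
    return {"num_experts": num_experts, **extra}

experts = make_experts(
    num_experts=64,
    # Extra kwargs are only unpacked when FlashInfer MoE is enabled.
    **(
        dict(
            enable_flashinfer_moe=True,
            enable_ep_moe=server_args["enable_ep_moe"],
        )
        if server_args["enable_flashinfer_moe"]
        else {}
    ),
)
print(experts)  # {'num_experts': 64, 'enable_flashinfer_moe': True, 'enable_ep_moe': False}
```
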
@@ -291,6 +300,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -393,6 +403,7 @@ class Qwen2MoeModel(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -418,6 +429,7 @@ class Qwen2MoeModel(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
@@ -428,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -447,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         if forward_batch.can_run_tbo:
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
@@ -459,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(
+                        hidden_states + residual
+                        if residual is not None
+                        else hidden_states
+                    )
                 with get_global_expert_distribution_recorder().with_current_layer(i):
                     layer = self.layers[i]
                     hidden_states, residual = layer(
@@ -477,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
sglang/srt/models/qwen2_vl.py
@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
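
Both VL models (this hunk and the qwen2_5_vl.py hunk above) now build MultiModalityDataPaddingPatternMultimodalTokens with no argument: the pattern discovers the multimodal token IDs from mm_inputs at pad time instead of being handed [im_token_id] up front, letting one code path cover image, video, and audio tokens. A toy sketch of the idea (all names below are illustrative, not the sglang API):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyMMInputs:
    # Token IDs travel with the inputs themselves (illustrative field).
    mm_token_ids: List[int] = field(default_factory=lambda: [151646])

class ToyPaddingPattern:
    """No-arg pattern: token IDs are read per call, not fixed at init."""

    def pad_input_tokens(self, input_ids: List[int], mm: ToyMMInputs) -> List[int]:
        ids = set(mm.mm_token_ids)
        # The real pattern expands each multimodal token into per-item pad
        # regions; here we just mark them with a placeholder value.
        return [-1 if t in ids else t for t in input_ids]

pattern = ToyPaddingPattern()  # note: no token-ID list needed anymore
print(pattern.pad_input_tokens([1, 151646, 2], ToyMMInputs()))  # [1, -1, 2]
```
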