sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/models/qwen2_audio.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+import logging
+import math
+from functools import lru_cache, partial
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
+from transformers.models.qwen2_audio.modeling_qwen2_audio import (
+    Qwen2AudioEncoder,
+    Qwen2AudioMultiModalProjector,
+)
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2AudioForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2AudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        if getattr(self.config, "audio_config", None) is None:
+            self.config.audio_config = Qwen2AudioEncoderConfig(
+                self.config._name_or_path
+            )
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.audio_tower = Qwen2AudioEncoder(
+            config.audio_config,
+        )
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs for audio
+        audio_token_id: int = getattr(
+            mm_inputs, "audio_token_id", mm_inputs.im_token_id
+        )
+
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract audio features from input items
+        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+            self.audio_tower.dtype
+        )
+
+        audio_embeds = self.audio_tower(input_features).last_hidden_state
+        audio_embeds = self.multi_modal_projector(audio_embeds)
+
+        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
+        new_embeds = []
+        for i, d in zip(audio_feature_lens, audio_embeds):
+            new_embeds.append(d[: i.item()])
+
+        return torch.cat(new_embeds, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name or "audio_tower" in name:
+                    continue
+                name_tmp = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
+                    continue
+                param = params_dict[name_tmp]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2AudioForConditionalGeneration
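The load_weights loop in this new file uses the common fused-parameter routing pattern: checkpoint tensors named q_proj/k_proj/v_proj (or gate_proj/up_proj) are redirected into the fused qkv_proj (or gate_up_proj) parameter with a shard id, and anything that matches nothing falls through the for/else branch and loads under its own name. A minimal, self-contained sketch of just the name routing, with toy tensor names (the real loader also excludes audio_tower names and dispatches to per-parameter weight_loader callbacks):

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def route(name):
        # Return (destination_param, shard_id) for a checkpoint tensor name.
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name), shard_id
        return name, None  # the for/else case: tensor loads as-is

    assert route("layers.0.self_attn.k_proj.weight") == (
        "layers.0.self_attn.qkv_proj.weight",
        "k",
    )
    assert route("layers.0.mlp.down_proj.weight") == (
        "layers.0.mlp.down_proj.weight",
        None,
    )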
--- a/sglang/srt/models/qwen2_moe.py
+++ b/sglang/srt/models/qwen2_moe.py
@@ -31,6 +31,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -64,11 +69,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -143,6 +143,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
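Note on the FusedMoE change above: the extra flashinfer arguments are passed through conditional keyword expansion, so the constructor only sees them when the feature flag is on. A minimal sketch of the idiom with toy flags and a toy callee (not sglang's real server args):

    def make_moe(renormalize, **extra):
        # A callee that may not accept the optional flags at all.
        return {"renormalize": renormalize, **extra}

    server_args = {"enable_flashinfer_moe": True, "enable_ep_moe": False}

    moe = make_moe(
        renormalize=True,
        # Expands to {} when the flag is off, so no unexpected kwargs reach the callee.
        **(
            dict(
                enable_flashinfer_moe=True,
                enable_ep_moe=server_args["enable_ep_moe"],
            )
            if server_args["enable_flashinfer_moe"]
            else {}
        ),
    )
    assert moe == {
        "renormalize": True,
        "enable_flashinfer_moe": True,
        "enable_ep_moe": False,
    }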
@@ -291,6 +300,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -393,6 +403,7 @@ class Qwen2MoeModel(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -418,6 +429,7 @@ class Qwen2MoeModel(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
@@ -428,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -447,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         if forward_batch.can_run_tbo:
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
@@ -459,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(
+                        hidden_states + residual
+                        if residual is not None
+                        else hidden_states
+                    )
                 with get_global_expert_distribution_recorder().with_current_layer(i):
                     layer = self.layers[i]
                     hidden_states, residual = layer(
@@ -477,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
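The forward changes above add EAGLE3-style auxiliary hidden-state capture: before each layer listed in layers_to_capture runs, the model records hidden_states + residual, and it returns a (hidden_states, aux_hidden_states) tuple when anything was captured. A minimal torch sketch of the pattern with toy linear layers and simplified residual handling (not the real decoder):

    import torch
    from torch import nn

    class ToyDecoder(nn.Module):
        def __init__(self, num_layers=8, dim=16):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
            self.layers_to_capture = []  # e.g. [2, num_layers // 2, num_layers - 3]

        def forward(self, x):
            aux_hidden_states = []
            hidden_states, residual = x, None
            for i, layer in enumerate(self.layers):
                if i in self.layers_to_capture:
                    # Record the full pre-layer state, residual included.
                    aux_hidden_states.append(
                        hidden_states + residual if residual is not None else hidden_states
                    )
                residual = hidden_states
                hidden_states = layer(hidden_states)
            if len(aux_hidden_states) == 0:
                return hidden_states
            return hidden_states, aux_hidden_states

    model = ToyDecoder()
    model.layers_to_capture = [2, 4, 5]
    out, aux = model(torch.randn(1, 16))
    assert len(aux) == 3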
--- a/sglang/srt/models/qwen2_vl.py
+++ b/sglang/srt/models/qwen2_vl.py
@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
--- a/sglang/srt/models/qwen3.py
+++ b/sglang/srt/models/qwen3.py
@@ -2,7 +2,7 @@
 
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -11,9 +11,9 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -23,15 +23,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3Attention(nn.Module):
@@ -49,23 +51,27 @@ class Qwen3Attention(nn.Module):
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -84,6 +90,8 @@ class Qwen3Attention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -91,6 +99,9 @@ class Qwen3Attention(nn.Module):
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -109,15 +120,27 @@ class Qwen3Attention(nn.Module):
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
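The _apply_qk_norm rewrite above overlaps the independent Q-norm and K-norm kernels on two CUDA streams while CUDA graphs are being captured. A minimal sketch of the stream-overlap pattern, using standalone RMSNorm modules rather than sglang's kernels (assumes a CUDA device and a PyTorch recent enough to ship torch.nn.RMSNorm):

    import torch

    def overlapped_norms(q, k, q_norm, k_norm, alt_stream):
        current_stream = torch.cuda.current_stream()
        # alt_stream must see all prior work on the current stream before it reads k.
        alt_stream.wait_stream(current_stream)
        q = q_norm(q)  # runs on the current stream
        with torch.cuda.stream(alt_stream):
            k = k_norm(k)  # runs concurrently on alt_stream
        # Re-join: later ops on the current stream wait for the k-norm to finish.
        current_stream.wait_stream(alt_stream)
        return q, k

    if torch.cuda.is_available():
        q_norm = torch.nn.RMSNorm(128, device="cuda")
        k_norm = torch.nn.RMSNorm(128, device="cuda")
        q = torch.randn(256, 128, device="cuda")
        k = torch.randn(256, 128, device="cuda")
        q, k = overlapped_norms(q, k, q_norm, k_norm, torch.cuda.Stream())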
@@ -143,6 +166,7 @@ class Qwen3DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,6 +187,7 @@ class Qwen3DecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -176,6 +201,18 @@ class Qwen3DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=False,
+            is_previous_layer_sparse=False,
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -184,20 +221,24 @@ class Qwen3DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
         hidden_states = self.mlp(hidden_states)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
         return hidden_states, residual
 
 
@@ -208,11 +249,13 @@ class Qwen3Model(Qwen2Model):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )
 
 
@@ -282,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -303,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -404,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
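With set_eagle3_layers_to_capture, the default capture set is [2, num_layers // 2, num_layers - 3] (one early, one middle, one late layer), and explicit layer_ids are shifted by one, presumably because the capture loop records the state before layer i runs, so the output of layer val is observed at index val + 1. A quick check of the arithmetic for a hypothetical 36-layer config:

    num_layers = 36  # hypothetical config.num_hidden_layers
    default_capture = [2, num_layers // 2, num_layers - 3]
    assert default_capture == [2, 18, 33]

    layer_ids = [1, 17, 32]  # capture the outputs of these layers
    assert [val + 1 for val in layer_ids] == [2, 18, 33]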