sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
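
The listing above comes from the registry's own diff view. A similar changed-file summary can be rebuilt locally from the two wheels; the sketch below is illustrative only, assumes both .whl files have already been downloaded into the working directory (for example with `pip download sglang==0.5.0rc0 --no-deps` and `pip download sglang==0.5.0rc2 --no-deps`), and its +/- counts may differ slightly from the rendering above.

# Sketch: reproduce a changed-file listing (with +/- line counts) from two wheels.
# The wheel filenames below are assumptions about what `pip download` produced locally.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.5.0rc0-py3-none-any.whl"
NEW_WHEEL = "sglang-0.5.0rc2-py3-none-any.whl"


def read_members(path: str) -> dict[str, bytes]:
    """Return {member name: raw bytes} for every file in the wheel archive."""
    with zipfile.ZipFile(path) as zf:
        return {n: zf.read(n) for n in zf.namelist() if not n.endswith("/")}


def line_counts(old: bytes, new: bytes) -> tuple[int, int]:
    """Count added/removed lines, treating undecodable members as opaque blobs."""
    try:
        old_lines = old.decode("utf-8").splitlines()
        new_lines = new.decode("utf-8").splitlines()
    except UnicodeDecodeError:
        return 0, 0
    added = removed = 0
    for line in difflib.unified_diff(old_lines, new_lines, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    return added, removed


old_files = read_members(OLD_WHEEL)
new_files = read_members(NEW_WHEEL)
for name in sorted(old_files.keys() | new_files.keys()):
    if old_files.get(name) == new_files.get(name):
        continue  # unchanged member
    plus, minus = line_counts(old_files.get(name, b""), new_files.get(name, b""))
    print(f"{name} +{plus} -{minus}")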
sglang/srt/models/glm4v_moe.py (new file)
@@ -0,0 +1,400 @@
+import logging
+from functools import lru_cache
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    parallel_state,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.glm4_moe import Glm4MoeModel
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+_is_cuda = is_cuda()
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+    def __init__(
+        self,
+        config: Glm4vMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+
+        config.moe_layer_freq = 1
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_local_attention_dp_size()
+        self.quant_config = quant_config
+        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+
+        self.model = Glm4MoeModel(
+            config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.visual = Glm4vVisionModel(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+    def determine_num_fused_shared_experts(
+        self, architecture: str = "Glm4MoeForCausalLM"
+    ):
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                    or self.quant_config.get_name() == "compressed_tensors"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                elif self.quant_config.get_name() == "modelopt_fp4":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "down_proj.weight_scale_2",
+                        "down_proj.input_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "gate_proj.weight_scale_2",
+                        "gate_proj.input_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                        "up_proj.weight_scale_2",
+                        "up_proj.input_scale",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "gate_proj.weight",
+                    "up_proj.weight",
+                ]
+            names_to_remove = []
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in moe_layers:
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    # online fp8 quantization does not load weight_scale
+                    if shared_expert_weight_name not in weights_dict:
+                        continue
+                    weights_list.append(
+                        (
+                            f"model.layers.{moe_layer}."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + 0}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                    names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
+        params_dict = dict(self.named_parameters())
+        weight_names = []
+        for name, loaded_weight in weights:
+            weight_names.append(name)
+
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            if "language_model." in name:
+                name = name.replace("language_model.", "")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+
+
+EntryClass = [Glm4vMoeForConditionalGeneration]
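
Most of load_weights above handles shared-expert fusion: when fusion stays enabled, the single shared expert's checkpoint tensors are re-keyed so they load as one extra routed expert with index n_routed_experts. The standalone sketch below illustrates only that re-keying step; the layer indices, expert count, and zero tensors are made up for the example and do not come from a real GLM-4.5 checkpoint.

# Minimal sketch of the shared-expert re-keying done in load_weights above.
# The weight names and layer indices here are illustrative, not a real checkpoint.
import torch

n_routed_experts = 4          # real GLM-4.5 configs use far more routed experts
moe_layers = [1, 2]           # layers whose MLP is a MoE block
suffix_list = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

weights = {
    f"model.layers.{i}.mlp.shared_experts.{s}": torch.zeros(1)
    for i in moe_layers
    for s in suffix_list
}

fused = {}
for i in moe_layers:
    for s in suffix_list:
        shared_name = f"model.layers.{i}.mlp.shared_experts.{s}"
        if shared_name not in weights:
            continue  # e.g. online fp8 quantization ships no weight_scale tensors
        # The shared expert becomes routed expert number `n_routed_experts`.
        fused[f"model.layers.{i}.mlp.experts.{n_routed_experts}.{s}"] = weights[shared_name]

print(sorted(fused))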
sglang/srt/models/gpt_oss.py
@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -56,7 +57,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -64,7 +65,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+
+
+if _is_cuda:
+    from sgl_kernel import FusedSetKVBufferArg
 
 
 class GptOssConfig(PretrainedConfig):
@@ -151,10 +166,13 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
 
@@ -165,7 +183,11 @@ class GptOssSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -179,13 +201,39 @@ class GptOssSparseMoeBlock(nn.Module):
         kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         ans = final_hidden_states.view(num_tokens, hidden_dim)
         return ans
 
 
+def _enable_fused_set_kv_buffer():
+    return _is_cuda
+
+
+# TODO maybe move to a model-common utils
+def _create_fused_set_kv_buffer_arg(
+    value: torch.Tensor,
+    layer: RadixAttention,
+    forward_batch: ForwardBatch,
+):
+    layer_id = layer.layer_id
+    token_to_kv_pool = forward_batch.token_to_kv_pool
+
+    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+    return FusedSetKVBufferArg(
+        value=value,
+        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+        k_scale=layer.k_scale,
+        v_scale=layer.v_scale,
+        cache_loc=forward_batch.out_cache_loc,
+    )
+
+
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -246,8 +294,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -293,7 +345,21 @@ class GptOssAttention(nn.Module):
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                _create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if _enable_fused_set_kv_buffer()
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -301,7 +367,11 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        attn_output = self.attn(
+            *inner_state,
+            sinks=self.sinks,
+            save_kv_cache=not _enable_fused_set_kv_buffer(),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
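The three hunks above route the KV-cache write through the rotary-embedding call on CUDA: rotary_emb receives a FusedSetKVBufferArg built from the layer's key/value buffers, and the attention call then passes save_kv_cache=False so the cache is not written a second time. The sketch below models only that control flow; the class, function names, and tensor shapes are stand-ins, not sglang's real kernel or KV-pool APIs.

# Sketch of the fused set-KV-buffer control flow, using plain Python stand-ins
# for sglang's RoPE kernel and KV cache (all names here are illustrative).
from dataclasses import dataclass

import torch


@dataclass
class FusedSetKVBufferArgSketch:
    value: torch.Tensor      # v, stored alongside k during the RoPE kernel
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    cache_loc: torch.Tensor  # slot indices for the current tokens


def apply_rope_and_maybe_store(q, k, fused_arg=None):
    # Real code applies rotary embeddings here; this sketch only models the store.
    if fused_arg is not None:
        fused_arg.k_buffer[fused_arg.cache_loc] = k
        fused_arg.v_buffer[fused_arg.cache_loc] = fused_arg.value
    return q, k


def attention(q, k, v, save_kv_cache, k_buffer, v_buffer, cache_loc):
    if save_kv_cache:  # only needed when the RoPE kernel did not already store K/V
        k_buffer[cache_loc] = k
        v_buffer[cache_loc] = v
    # ... run attention against the cache ...


fuse = True  # corresponds to _enable_fused_set_kv_buffer() returning True on CUDA
k_buf, v_buf = torch.zeros(8, 4), torch.zeros(8, 4)
q = k = v = torch.ones(2, 4)
loc = torch.tensor([0, 1])

arg = FusedSetKVBufferArgSketch(v, k_buf, v_buf, loc) if fuse else None
q, k = apply_rope_and_maybe_store(q, k, arg)
attention(q, k, v, save_kv_cache=not fuse, k_buffer=k_buf, v_buffer=v_buf, cache_loc=loc)
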
@@ -370,6 +440,7 @@ class GptOssDecoderLayer(nn.Module):
 
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
+        self.is_nextn = False
         is_previous_layer_sparse = True
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -402,6 +473,42 @@ class GptOssDecoderLayer(nn.Module):
             post_attention_layernorm=self.post_attention_layernorm,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
     def forward(
         self,
         positions: torch.Tensor,
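
_build_fuse_allreduce_lookup_table folds the static conditions (not the last layer, TP > 1, the flashinfer allreduce-fusion flag, SM100 support, flashinfer availability) into a per-batch-size table at construction time, so forward only performs a dictionary lookup. A condensed and slightly simplified sketch of that gating follows; the configuration values at the bottom are illustrative, not taken from a real server.

# Condensed sketch of the allreduce-fusion gate added above: static conditions are
# folded into a lookup table once, and the per-forward check is a dict get.
def build_lookup(layer_id, num_hidden_layers, tp_size, fusion_flag_enabled,
                 sm100_supported, flashinfer_available):
    static_ok = (
        layer_id != num_hidden_layers - 1
        and tp_size > 1
        and fusion_flag_enabled
        and sm100_supported
        and flashinfer_available
    )
    if not static_ok:
        return {}
    # Fusion is only attempted for small batches (1..128) on non-final layers.
    return {bs: bs > 0 for bs in range(129)}


def should_fuse(lookup, batch_size):
    if batch_size > 128:
        return False
    return lookup.get(batch_size, False)


table = build_lookup(layer_id=0, num_hidden_layers=36, tp_size=4,
                     fusion_flag_enabled=True, sm100_supported=True,
                     flashinfer_available=True)
print(should_fuse(table, 16), should_fuse(table, 0), should_fuse(table, 512))
# -> True False False
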
@@ -424,12 +531,21 @@ class GptOssDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        should_allreduce_fusion = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not self.is_nextn
         )
 
+        hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not should_allreduce_fusion:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
 
@@ -450,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -550,6 +666,18 @@ class GptOssForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,
@@ -1033,12 +1161,6 @@ class GptOssForCausalLM(nn.Module):
         else:
             logging.info("All parameters loaded successfully.")
 
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
-        }
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
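
The final two hunks replace the dict that load_weights used to build eagerly with a LazyValue that is evaluated on first access of the routed_experts_weights_of_layer property. LazyValue itself is imported from sglang.srt.utils and its implementation is not part of this diff; the stand-in below only illustrates the compute-once-on-first-access behavior it is being used for.

# Minimal stand-in for the lazy-initialization pattern used above. This is not
# sglang's actual LazyValue implementation, only an illustration of the idea.
class LazyValueSketch:
    def __init__(self, factory):
        self._factory = factory
        self._value = None
        self._computed = False

    @property
    def value(self):
        if not self._computed:  # compute once, on first access
            self._value = self._factory()
            self._computed = True
        return self._value


def expensive_moe_weight_gather():
    print("gathering MoE weights ...")  # runs only when the property is first read
    return {0: "weights for layer 0"}


lazy = LazyValueSketch(expensive_moe_weight_gather)
# Nothing is gathered at construction time (e.g. during weight loading);
# the dict is built on the first access and cached afterwards.
print(lazy.value)
print(lazy.value)  # cached, the factory is not called again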