sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the content changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4v_moe.py
@@ -0,0 +1,400 @@
+ import logging
+ from functools import lru_cache
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
+
+ from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     parallel_state,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.hf_transformers_utils import get_processor
+ from sglang.srt.layers.dp_attention import (
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     get_local_attention_dp_size,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.glm4_moe import Glm4MoeModel
+ from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
+ from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+
+ _is_cuda = is_cuda()
+
+ logger = logging.getLogger(__name__)
+
+ cached_get_processor = lru_cache(get_processor)
+
+
+ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
+     def __init__(
+         self,
+         config: Glm4vMoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+
+         config.moe_layer_freq = 1
+         self.config = config
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.dp_size = get_local_attention_dp_size()
+         self.quant_config = quant_config
+         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+         self.num_fused_shared_experts = (
+             0
+             if global_server_args_dict["disable_shared_experts_fusion"]
+             else config.n_shared_experts
+         )
+
+         self.model = Glm4MoeModel(
+             config,
+             quant_config,
+             prefix=add_prefix("language_model", prefix),
+         )
+         self.visual = Glm4vVisionModel(
+             config.vision_config,
+             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
+             quant_config=quant_config,
+             prefix=add_prefix("visual", prefix),
+         )
+
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("lm_head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+     def determine_num_fused_shared_experts(
+         self, architecture: str = "Glm4MoeForCausalLM"
+     ):
+         self.num_fused_shared_experts = 0
+         if global_server_args_dict["disable_shared_experts_fusion"]:
+             return
+
+         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+         disable_reason = None
+         if (
+             not _is_cuda
+             or torch.cuda.get_device_capability("cuda") < (8, 0)
+             or self.config.architectures[0] != architecture
+             or self.config.n_shared_experts != 1
+         ):
+             disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+         elif get_moe_expert_parallel_world_size() > 1:
+             disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+
+         if disable_reason is not None:
+             global_server_args_dict["disable_shared_experts_fusion"] = True
+             self.num_fused_shared_experts = 0
+             log_info_on_rank0(
+                 logger,
+                 f"{disable_reason} Shared experts fusion optimization is disabled.",
+             )
+             return
+
+         self.num_fused_shared_experts = self.config.n_shared_experts
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
+         if is_nextn:
+             if hasattr(self.config, "num_nextn_predict_layers"):
+                 num_nextn_layers = self.config.num_nextn_predict_layers
+                 assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                 # compatible with old design
+                 nextn_layer_id = (
+                     0
+                     if self.config.num_hidden_layers == 1
+                     else self.config.num_hidden_layers
+                 )
+             else:
+                 raise ValueError("num_nextn_predict_layers is not in the config")
+
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         if self.num_fused_shared_experts > 0:
+             assert self.num_fused_shared_experts == 1
+             weights_list = list(weights)
+             weights_dict = dict(weights_list)
+             if self.quant_config is not None:
+                 if self.quant_config.get_name() == "w8a8_int8":
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                     ]
+                 elif (
+                     self.quant_config.get_name() == "fp8"
+                     or self.quant_config.get_name() == "blockwise_int8"
+                     or self.quant_config.get_name() == "compressed_tensors"
+                 ):
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                     ]
+                 elif self.quant_config.get_name() == "awq":
+                     suffix_list = [
+                         "down_proj.qweight",
+                         "down_proj.qzeros",
+                         "down_proj.scales",
+                         "gate_proj.qweight",
+                         "gate_proj.qzeros",
+                         "gate_proj.scales",
+                         "up_proj.qweight",
+                         "up_proj.qzeros",
+                         "up_proj.scales",
+                     ]
+                 elif self.quant_config.get_name() == "modelopt_fp4":
+                     suffix_list = [
+                         "down_proj.weight",
+                         "down_proj.weight_scale",
+                         "down_proj.weight_scale_2",
+                         "down_proj.input_scale",
+                         "gate_proj.weight",
+                         "gate_proj.weight_scale",
+                         "gate_proj.weight_scale_2",
+                         "gate_proj.input_scale",
+                         "up_proj.weight",
+                         "up_proj.weight_scale",
+                         "up_proj.weight_scale_2",
+                         "up_proj.input_scale",
+                     ]
+                 else:
+                     raise ValueError(
+                         f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                     )
+             else:
+                 suffix_list = [
+                     "down_proj.weight",
+                     "gate_proj.weight",
+                     "up_proj.weight",
+                 ]
+             names_to_remove = []
+
+             moe_layers = (
+                 range(
+                     self.config.first_k_dense_replace,
+                     self.config.num_hidden_layers,
+                     self.config.moe_layer_freq,
+                 )
+                 if not is_nextn
+                 else [nextn_layer_id]
+             )
+
+             for moe_layer in moe_layers:
+                 for suffix in suffix_list:
+                     shared_expert_weight_name = (
+                         f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                     )
+                     # online fp8 quantization does not load weight_scale
+                     if shared_expert_weight_name not in weights_dict:
+                         continue
+                     weights_list.append(
+                         (
+                             f"model.layers.{moe_layer}."
+                             f"mlp.experts."
+                             f"{self.config.n_routed_experts + 0}"
+                             f".{suffix}",
+                             weights_dict[shared_expert_weight_name],
+                         )
+                     )
+                     names_to_remove += [shared_expert_weight_name]
+             weights = [w for w in weights_list if w[0] not in names_to_remove]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+         )
+
+         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+             self.config.q_lora_rank is not None
+         )
+         cached_a_proj = {} if fuse_qkv_a_proj else None
+
+         if is_nextn:
+             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+             nextn_spec_weight_names = [
+                 "shared_head.norm",
+                 "eh_proj",
+                 "enorm",
+                 "hnorm",
+             ]
+
+         params_dict = dict(self.named_parameters())
+         weight_names = []
+         for name, loaded_weight in weights:
+             weight_names.append(name)
+
+             if not is_nextn:
+                 if hasattr(self.config, "num_nextn_predict_layers"):
+                     num_nextn_layers = self.config.num_nextn_predict_layers
+                     if num_nextn_layers > 0 and name.startswith("model.layers"):
+                         name_list = name.split(".")
+                         if (
+                             len(name_list) >= 3
+                             and int(name_list[2]) >= self.config.num_hidden_layers
+                         ):
+                             continue
+             else:
+                 if not name.startswith(nextn_layer_prefix):
+                     continue
+
+                 # Use shared head and embed weights from target model
+                 if "shared_head.head" in name or "embed_tokens" in name:
+                     continue
+
+                 is_decoder = True
+                 # For nextn specific weights
+                 for weight_name in nextn_spec_weight_names:
+                     if weight_name in name:
+                         name = name.replace(nextn_layer_prefix, "model")
+                         is_decoder = False
+                         break
+                 # For decoder layer weights
+                 if is_decoder:
+                     name = name.replace(nextn_layer_prefix, "model.decoder")
+
+             if "language_model." in name:
+                 name = name.replace("language_model.", "")
+             if "model.visual." in name:
+                 name = name.replace("model.visual.", "visual.")
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if ("mlp.experts." in name) and name not in params_dict:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     if "visual" in name:
+                         # adapt to VisionAttention
+                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     if fuse_qkv_a_proj and (
+                         "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                     ):
+                         cached_a_proj[name] = loaded_weight
+                         q_a_proj_name = (
+                             name
+                             if "q_a_proj" in name
+                             else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                         )
+                         kv_a_proj_name = (
+                             name
+                             if "kv_a_proj_with_mqa" in name
+                             else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                         )
+
+                         # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                         if (
+                             q_a_proj_name in cached_a_proj
+                             and kv_a_proj_name in cached_a_proj
+                         ):
+                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                             fused_weight = torch.cat(
+                                 [q_a_proj_weight, kv_a_proj_weight], dim=0
+                             )
+                             param_name = (
+                                 name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                 if "q_a_proj" in name
+                                 else name.replace(
+                                     "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                 )
+                             )
+                             param = params_dict[param_name]
+
+                             weight_loader = getattr(
+                                 param, "weight_loader", default_weight_loader
+                             )
+                             weight_loader(param, fused_weight)
+                             cached_a_proj.pop(q_a_proj_name)
+                             cached_a_proj.pop(kv_a_proj_name)
+                     else:
+                         if (
+                             "k_scale" in name or "v_scale" in name
+                         ) and name not in params_dict:
+                             # modelopt attn kv scale is named differently
+                             if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                 name = name.replace("_proj", "attn_mqa")
+                             else:
+                                 logger.warning(
+                                     f"Unknown scale found in checkpoint: {name}"
+                                 )
+                         param = params_dict[name]
+                         weight_loader = getattr(
+                             param, "weight_loader", default_weight_loader
+                         )
+                         weight_loader(param, loaded_weight)
+
+
+ EntryClass = [Glm4vMoeForConditionalGeneration]
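
Note on the new glm4v_moe.py above: when shared-experts fusion is enabled, load_weights renames each MoE layer's mlp.shared_experts.* tensors to one extra routed-expert slot (index n_routed_experts) so that the ordinary expert_params_mapping path loads them. The following is a minimal standalone sketch of just that renaming step; the remap_shared_expert_weights helper and the toy tensor names are illustrative and not part of sglang.

from typing import Dict, Iterable

def remap_shared_expert_weights(
    weights: Dict[str, object],
    moe_layers: Iterable[int],
    n_routed_experts: int,
    suffixes: Iterable[str] = ("gate_proj.weight", "up_proj.weight", "down_proj.weight"),
) -> Dict[str, object]:
    # Rename "model.layers.<i>.mlp.shared_experts.<suffix>" to
    # "model.layers.<i>.mlp.experts.<n_routed_experts>.<suffix>", mirroring the
    # renaming performed in Glm4vMoeForConditionalGeneration.load_weights above.
    remapped = dict(weights)
    for layer in moe_layers:
        for suffix in suffixes:
            src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            if src not in remapped:
                continue  # e.g. online fp8 quantization ships no weight_scale
            dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts}.{suffix}"
            remapped[dst] = remapped.pop(src)
    return remapped

if __name__ == "__main__":
    toy = {"model.layers.3.mlp.shared_experts.gate_proj.weight": "W"}
    print(remap_shared_expert_weights(toy, moe_layers=[3], n_routed_experts=128))
    # {'model.layers.3.mlp.experts.128.gate_proj.weight': 'W'}
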
sglang/srt/models/gpt_oss.py
@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
- from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -64,7 +64,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix, make_layers
+ from sglang.srt.utils import (
+     LazyValue,
+     add_prefix,
+     is_cuda,
+     is_flashinfer_available,
+     make_layers,
+ )
+
+ _is_cuda = is_cuda()
+ _is_flashinfer_available = is_flashinfer_available()
+ _is_sm100_supported = is_cuda() and is_sm100_supported()
+
+
+ if _is_cuda:
+     from sgl_kernel import FusedSetKVBufferArg


  class GptOssConfig(PretrainedConfig):
@@ -151,10 +165,13 @@ class GptOssSparseMoeBlock(nn.Module):
          )

      def forward(
-         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: Optional[ForwardBatch] = None,
+         should_allreduce_fusion: bool = False,
      ) -> torch.Tensor:
          if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-             return self.forward_normal(hidden_states)
+             return self.forward_normal(hidden_states, should_allreduce_fusion)
          else:
              raise Exception("forward_deepep branch not implemented yet")

@@ -165,7 +182,11 @@ class GptOssSparseMoeBlock(nn.Module):
              if name not in ["correction_bias"]
          ]

-     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+     def forward_normal(
+         self,
+         hidden_states: torch.Tensor,
+         should_allreduce_fusion: bool = False,
+     ) -> torch.Tensor:
          num_tokens, hidden_dim = hidden_states.shape
          hidden_states = hidden_states.view(-1, hidden_dim)

@@ -179,13 +200,39 @@ class GptOssSparseMoeBlock(nn.Module):
              kwargs["topk_output"] = (self.top_k, router_logits)
          final_hidden_states = self.experts(**kwargs)

-         if self.tp_size > 1:
+         if self.tp_size > 1 and not should_allreduce_fusion:
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

          ans = final_hidden_states.view(num_tokens, hidden_dim)
          return ans


+ def _enable_fused_set_kv_buffer():
+     return _is_cuda
+
+
+ # TODO maybe move to a model-common utils
+ def _create_fused_set_kv_buffer_arg(
+     value: torch.Tensor,
+     layer: RadixAttention,
+     forward_batch: ForwardBatch,
+ ):
+     layer_id = layer.layer_id
+     token_to_kv_pool = forward_batch.token_to_kv_pool
+
+     k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+     v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+     return FusedSetKVBufferArg(
+         value=value,
+         k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+         v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+         k_scale=layer.k_scale,
+         v_scale=layer.v_scale,
+         cache_loc=forward_batch.out_cache_loc,
+     )
+
+
  class GptOssAttention(nn.Module):
      def __init__(
          self,
@@ -247,7 +294,7 @@ class GptOssAttention(nn.Module):
          )

          self.sinks = nn.Parameter(
-             torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+             torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
          )

          self.o_proj = RowParallelLinear(
@@ -293,7 +340,21 @@ class GptOssAttention(nn.Module):
              return hidden_states, forward_batch, None
          qkv, _ = self.qkv_proj(hidden_states)
          q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-         q, k = self.rotary_emb(positions, q, k)
+
+         q, k = self.rotary_emb(
+             positions,
+             q,
+             k,
+             fused_set_kv_buffer_arg=(
+                 _create_fused_set_kv_buffer_arg(
+                     value=v,
+                     layer=self.attn,
+                     forward_batch=forward_batch,
+                 )
+                 if _enable_fused_set_kv_buffer()
+                 else None
+             ),
+         )
          inner_state = q, k, v, forward_batch
          return None, forward_batch, inner_state

@@ -301,7 +362,11 @@ class GptOssAttention(nn.Module):
          hidden_states, forward_batch, inner_state = intermediate_state
          if inner_state is None:
              return hidden_states
-         attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+         attn_output = self.attn(
+             *inner_state,
+             sinks=self.sinks,
+             save_kv_cache=not _enable_fused_set_kv_buffer(),
+         )
          output, _ = self.o_proj(attn_output)
          return output

@@ -370,6 +435,7 @@ class GptOssDecoderLayer(nn.Module):

          # GptOss all layers are sparse and have no nextn now
          self.is_layer_sparse = True
+         self.is_nextn = False
          is_previous_layer_sparse = True

          self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -402,6 +468,42 @@ class GptOssDecoderLayer(nn.Module):
              post_attention_layernorm=self.post_attention_layernorm,
          )

+         self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+         """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+
+         batch_size = (
+             forward_batch.input_ids.shape[0]
+             if hasattr(forward_batch, "input_ids")
+             else 0
+         )
+
+         if batch_size > 128:
+             return False
+
+         return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+     def _build_fuse_allreduce_lookup_table(self):
+         static_conditions_met = (
+             self.layer_id != self.config.num_hidden_layers - 1
+             and get_tensor_model_parallel_world_size() > 1
+             and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+             and _is_sm100_supported
+             and _is_flashinfer_available
+         )
+
+         if not static_conditions_met:
+             return {}
+
+         lookup_table = {}
+         for batch_size in range(129):  # 0 to 128
+             is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+             should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+             lookup_table[batch_size] = should_fuse
+
+         return lookup_table
+
      def forward(
          self,
          positions: torch.Tensor,
@@ -424,12 +526,21 @@
              hidden_states, residual, forward_batch
          )

-         hidden_states = self.mlp(hidden_states, forward_batch)
-
-         hidden_states, residual = self.layer_communicator.postprocess_layer(
-             hidden_states, residual, forward_batch
+         should_allreduce_fusion = (
+             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+             and not self.is_nextn
          )

+         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+         if should_allreduce_fusion:
+             hidden_states._sglang_needs_allreduce_fusion = True
+
+         if not should_allreduce_fusion:
+             hidden_states, residual = self.layer_communicator.postprocess_layer(
+                 hidden_states, residual, forward_batch
+             )
+
          return hidden_states, residual


@@ -550,6 +661,18 @@ class GptOssForCausalLM(nn.Module):
          self.logits_processor = LogitsProcessor(config)
          self.capture_aux_hidden_states = False

+         self._routed_experts_weights_of_layer = LazyValue(
+             lambda: {
+                 layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                 for layer_id in range(self.start_layer, self.end_layer)
+                 if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+             }
+         )
+
+     @property
+     def routed_experts_weights_of_layer(self):
+         return self._routed_experts_weights_of_layer.value
+
      @torch.no_grad()
      def forward(
          self,
@@ -1033,12 +1156,6 @@ class GptOssForCausalLM(nn.Module):
          else:
              logging.info("All parameters loaded successfully.")

-         self.routed_experts_weights_of_layer = {
-             layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-             for layer_id in range(self.start_layer, self.end_layer)
-             if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
-         }
-
      def get_embed_and_head(self):
          return self.model.embed_tokens.weight, self.lm_head.weight

sglang/srt/models/granite.py
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
          else:
              return self.pooler(hidden_states, forward_batch)

-     def get_hidden_dim(self, module_name):
-         # return input_dim, output_dim
-         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-             return self.config.hidden_size, self.config.hidden_size
-         elif module_name in ["kv_proj"]:
-             return self.config.hidden_size, self.config.hidden_size // (
-                 self.config.num_attention_heads // self.config.num_key_value_heads
-             )
-         elif module_name == "gate_up_proj":
-             return self.config.hidden_size, self.config.intermediate_size
-         elif module_name == "down_proj":
-             return self.config.intermediate_size, self.config.hidden_size
-         else:
-             raise NotImplementedError()
-
-     def get_module_name(self, name):
-         params_mapping = {
-             "q_proj": "qkv_proj",
-             "k_proj": "qkv_proj",
-             "v_proj": "qkv_proj",
-             "gate_proj": "gate_up_proj",
-             "up_proj": "gate_up_proj",
-         }
-         return params_mapping.get(name, name)
-
      def get_module_name_from_weight_name(self, name):
          for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
              if weight_name in name:
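
Looking back at the gpt_oss.py hunks above: _build_fuse_allreduce_lookup_table evaluates the static conditions for fusing the MLP all-reduce into the next layer's residual RMSNorm once per layer, and the per-forward decision then becomes a dictionary lookup keyed by batch size (0 to 128). The following is a minimal standalone sketch of that gating pattern; the AllreduceFusionGate class and its constructor arguments are illustrative and not sglang APIs, only the 128 threshold and the shape of the static checks are taken from the diff.

from typing import Dict

class AllreduceFusionGate:
    # Precompute a batch-size -> should-fuse table so the hot-path check is a
    # single dict lookup, mirroring the lookup-table pattern added to
    # GptOssDecoderLayer above.
    MAX_FUSED_BATCH = 128

    def __init__(self, is_last_layer: bool, tp_size: int, fusion_enabled: bool):
        # Static conditions are evaluated once; if they fail, the table stays
        # empty and should_fuse() always returns False.
        static_ok = not is_last_layer and tp_size > 1 and fusion_enabled
        self._table: Dict[int, bool] = (
            {bs: bs > 0 for bs in range(self.MAX_FUSED_BATCH + 1)} if static_ok else {}
        )

    def should_fuse(self, batch_size: int) -> bool:
        if batch_size > self.MAX_FUSED_BATCH:
            return False
        return self._table.get(batch_size, False)

if __name__ == "__main__":
    gate = AllreduceFusionGate(is_last_layer=False, tp_size=8, fusion_enabled=True)
    print(gate.should_fuse(32), gate.should_fuse(0), gate.should_fuse(512))
    # True False False
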