sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
  set_cpu_offload_max_bytes,
  set_cuda_arch,
  )
+ from sglang.srt.weight_sync.tensor_bucket import (
+ FlattenedTensorBucket,
+ FlattenedTensorMetadata,
+ )

  _is_hip = is_hip()
  _is_npu = is_npu()
@@ -378,6 +382,25 @@ class ModelRunner:
  )
  server_args.attention_backend = "torch_native"

+ if server_args.prefill_attention_backend is not None and (
+ server_args.prefill_attention_backend
+ == server_args.decode_attention_backend
+ ): # override the default attention backend
+ server_args.attention_backend = server_args.prefill_attention_backend
+
+ if (
+ getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+ is not None
+ ):
+ if server_args.attention_backend is None:
+ server_args.attention_backend = "dual_chunk_flash_attn"
+ logger.info("Dual chunk attention is turned on by default.")
+ elif server_args.attention_backend != "dual_chunk_flash_attn":
+ raise ValueError(
+ "Dual chunk attention is enabled, but attention backend is set to "
+ f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+ )
+
  if server_args.attention_backend is None:
  """
  Auto select the fastest attention backend.
@@ -397,7 +420,6 @@ class ModelRunner:
  is_hopper_with_cuda_12_3()
  and is_no_spec_infer_or_topk_one(server_args)
  and is_fa3_default_architecture(self.model_config.hf_config)
- and (not server_args.enable_hierarchical_cache)
  ):
  server_args.attention_backend = "fa3"
  elif _is_hip:
@@ -410,9 +432,7 @@ class ModelRunner:
  )
  else:
  # MLA architecture
- if is_hopper_with_cuda_12_3() and (
- not server_args.enable_hierarchical_cache
- ):
+ if is_hopper_with_cuda_12_3():
  server_args.attention_backend = "fa3"
  elif is_sm100_supported():
  server_args.attention_backend = "flashinfer"
@@ -500,6 +520,27 @@ class ModelRunner:
  if self.model_config.context_len > 8192:
  self.mem_fraction_static *= 0.85

+ if (
+ server_args.enable_hierarchical_cache
+ and server_args.hicache_io_backend == "kernel"
+ ):
+ # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+ if server_args.decode_attention_backend is None:
+ if not self.use_mla_backend:
+ server_args.decode_attention_backend = (
+ "flashinfer" if is_flashinfer_available() else "triton"
+ )
+ else:
+ server_args.decode_attention_backend = (
+ "flashinfer" if is_sm100_supported() else "triton"
+ )
+ elif server_args.decode_attention_backend == "fa3":
+ server_args.hicache_io_backend = "direct"
+ logger.warning(
+ "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+ f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+ )
+
  def init_torch_distributed(self):
  logger.info("Init torch distributed begin.")

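For readers tracing the new startup logic, here is a condensed, illustrative recap of the backend-selection precedence the hunks above introduce. This is a sketch only; the field names follow the diff, and the pre-existing auto-selection branch is elided.

```python
def resolve_attention_backend(server_args, hf_config):
    """Simplified recap of the new selection order (not the actual sglang code)."""
    # 1. An explicit prefill backend that matches the decode backend overrides
    #    the default attention backend.
    if (
        server_args.prefill_attention_backend is not None
        and server_args.prefill_attention_backend
        == server_args.decode_attention_backend
    ):
        return server_args.prefill_attention_backend

    # 2. Models shipping a dual_chunk_attention_config force dual_chunk_flash_attn
    #    and reject any other explicit choice.
    if getattr(hf_config, "dual_chunk_attention_config", None) is not None:
        if server_args.attention_backend in (None, "dual_chunk_flash_attn"):
            return "dual_chunk_flash_attn"
        raise ValueError("dual chunk attention requires 'dual_chunk_flash_attn'")

    # 3. Otherwise the pre-existing auto-selection (fa3 / flashinfer / triton / ...)
    #    still decides; note it no longer special-cases enable_hierarchical_cache.
    return server_args.attention_backend
```

The last hunk handles the HiCache side instead: with the kernel I/O backend, an unset decode backend is filled in with a compatible one, while an explicit FA3 decode backend downgrades hicache_io_backend to "direct".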
@@ -871,8 +912,18 @@ class ModelRunner:
  named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
  load_format: Optional[str] = None,
  ):
+ monkey_patch_torch_reductions()
+ if load_format == "flattened_bucket":
+ # Handle flattened bucket format
+ return self._update_weights_from_flattened_bucket(
+ flattened_tensor_bucket_dict=named_tensors
+ )
+
+ # We need to get device after patch otherwise the device would be wrong
+ infered_device = torch.cuda.current_device()
+
  named_tensors = [
- (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+ (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
  for name, tensor in named_tensors
  ]
  if load_format == "direct":
@@ -886,6 +937,38 @@ class ModelRunner:
  raise NotImplementedError(f"Unknown load_format={load_format}")
  return True, "Success"

+ def _update_weights_from_flattened_bucket(
+ self,
+ flattened_tensor_bucket_dict,
+ ):
+ """Handle flattened bucket format for weight updates"""
+ flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+ metadata = flattened_tensor_bucket_dict["metadata"]
+
+ # Convert metadata dict to our format
+ converted_metadata = []
+ for meta in metadata:
+ converted_meta = FlattenedTensorMetadata(
+ name=meta.name,
+ shape=meta.shape,
+ dtype=meta.dtype,
+ start_idx=meta.start_idx,
+ end_idx=meta.end_idx,
+ numel=meta.numel,
+ )
+ converted_metadata.append(converted_meta)
+
+ # Create bucket and reconstruct tensors
+ bucket = FlattenedTensorBucket(
+ flattened_tensor=flattened_tensor, metadata=converted_metadata
+ )
+ reconstructed_tensors = bucket.reconstruct_tensors()
+
+ # Load the reconstructed tensors using the standard method
+ self.model.load_weights(reconstructed_tensors)
+
+ return True, "Success"
+
  def get_weights_by_name(
  self, name: str, truncate_size: int = 100
  ) -> Optional[torch.Tensor]:
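The new `flattened_bucket` load format carries one contiguous buffer plus per-tensor metadata (name, shape, dtype, start/end offsets, numel). The sketch below is a minimal illustration of that idea, not the actual `FlattenedTensorBucket` added in `sglang/srt/weight_sync/tensor_bucket.py`; names other than the metadata fields shown above are made up, and it assumes every tensor in a bucket shares one dtype.

```python
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class FlatMeta:
    # Mirrors the metadata fields used in the hunk above.
    name: str
    shape: torch.Size
    dtype: torch.dtype
    start_idx: int
    end_idx: int
    numel: int


def flatten_bucket(named_tensors: List[Tuple[str, torch.Tensor]]):
    """Pack tensors into one contiguous 1-D buffer plus per-tensor metadata.

    Assumes every tensor in the bucket shares one dtype (torch.cat requires it).
    """
    metas, chunks, offset = [], [], 0
    for name, t in named_tensors:
        flat = t.reshape(-1)
        metas.append(
            FlatMeta(name, t.shape, t.dtype, offset, offset + flat.numel(), flat.numel())
        )
        chunks.append(flat)
        offset += flat.numel()
    return torch.cat(chunks), metas


def reconstruct_tensors(flattened: torch.Tensor, metas: List[FlatMeta]):
    """Slice the flat buffer back into (name, tensor) pairs."""
    return [
        (m.name, flattened[m.start_idx : m.end_idx].view(m.shape)) for m in metas
    ]


flat, metas = flatten_bucket([("w1", torch.randn(4, 8)), ("w2", torch.randn(16))])
assert [n for n, _ in reconstruct_tensors(flat, metas)] == ["w1", "w2"]
```

Sending one large buffer plus metadata avoids serializing many small tensors individually during weight sync, which is presumably why the new format exists.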
@@ -1181,30 +1264,33 @@ class ModelRunner:
  # Draft worker shares req_to_token_pool with the target worker.
  assert self.is_draft_worker

- if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
- self.token_to_kv_pool = AscendTokenToKVPool(
- self.max_total_num_tokens,
- page_size=self.page_size,
- dtype=self.kv_cache_dtype,
- head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
- head_dim=self.model_config.head_dim,
- layer_num=self.model_config.num_hidden_layers,
- device=self.device,
- enable_memory_saver=self.server_args.enable_memory_saver,
- )
- elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
- self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
- self.max_total_num_tokens,
- page_size=self.page_size,
- dtype=self.kv_cache_dtype,
- kv_lora_rank=self.model_config.kv_lora_rank,
- qk_rope_head_dim=self.model_config.qk_rope_head_dim,
- layer_num=self.num_effective_layers,
- device=self.device,
- enable_memory_saver=self.server_args.enable_memory_saver,
- start_layer=self.start_layer,
- end_layer=self.end_layer,
- )
+ if self.server_args.attention_backend == "ascend":
+ if self.use_mla_backend:
+ self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+ self.max_total_num_tokens,
+ page_size=self.page_size,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=self.model_config.kv_lora_rank,
+ qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+ layer_num=self.num_effective_layers,
+ device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
+ )
+ else:
+ self.token_to_kv_pool = AscendTokenToKVPool(
+ self.max_total_num_tokens,
+ page_size=self.page_size,
+ dtype=self.kv_cache_dtype,
+ head_num=self.model_config.get_num_kv_heads(
+ get_attention_tp_size()
+ ),
+ head_dim=self.model_config.head_dim,
+ layer_num=self.model_config.num_hidden_layers,
+ device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
+ )
  elif self.use_mla_backend:
  self.token_to_kv_pool = MLATokenToKVPool(
  self.max_total_num_tokens,
@@ -1263,6 +1349,7 @@ class ModelRunner:
  end_layer=self.end_layer,
  )

+ need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
  if self.token_to_kv_pool_allocator is None:
  if self.page_size == 1:
  if self.is_hybrid:
@@ -1272,6 +1359,7 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  device=self.device,
  kvcache=self.token_to_kv_pool,
+ need_sort=need_sort,
  )
  else:
  self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1279,23 +1367,26 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  device=self.device,
  kvcache=self.token_to_kv_pool,
+ need_sort=need_sort,
  )
  else:
- if _is_npu:
- self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+ if not _is_npu:
+ self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
  self.max_total_num_tokens,
  page_size=self.page_size,
  dtype=self.kv_cache_dtype,
  device=self.device,
  kvcache=self.token_to_kv_pool,
+ need_sort=need_sort,
  )
  else:
- self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+ self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
  self.max_total_num_tokens,
  page_size=self.page_size,
  dtype=self.kv_cache_dtype,
  device=self.device,
  kvcache=self.token_to_kv_pool,
+ need_sort=need_sort,
  )
  else:
  assert self.is_draft_worker
@@ -1396,6 +1487,10 @@ class ModelRunner:
  from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

  return AiterAttnBackend(self)
+ elif self.server_args.attention_backend == "wave":
+ from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+ return WaveAttnBackend(self)
  elif backend_str == "ascend":
  from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

@@ -1785,11 +1880,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
  default_weight_loader(params_dict[name], tensor)


- def _unwrap_tensor(tensor, tp_rank):
+ def _unwrap_tensor(tensor, tp_rank, device):
  if isinstance(tensor, LocalSerializedTensor):
- monkey_patch_torch_reductions()
  tensor = tensor.get(tp_rank)
- return tensor.to(torch.cuda.current_device())
+ return tensor.to(device)


  @dataclass
sglang/srt/model_loader/loader.py CHANGED
@@ -162,12 +162,24 @@ def _initialize_model(
  model_class, _ = get_model_architecture(model_config)
  packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
  if _is_npu:
- packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
- "q_a_proj",
- "kv_a_proj_with_mqa",
- ]
- packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
- packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
+ packed_modules_mapping.update(
+ {
+ "visual": {"qkv_proj": ["qkv"]},
+ "vision_model": {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "proj": ["out_proj"],
+ },
+ "model": {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ "fused_qkv_a_proj_with_mqa": [
+ "q_a_proj",
+ "kv_a_proj_with_mqa",
+ ],
+ },
+ }
+ )
+
  quant_config = _get_quantization_config(
  model_config, load_config, packed_modules_mapping
  )
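The NPU mapping is now nested by top-level submodule ("visual", "vision_model", "model"), letting vision towers and the language model declare different fused-parameter groups. The helper below is a hypothetical illustration of how such a nested mapping could be consulted when assigning checkpoint shards to fused parameters; it is not the loader's real resolution logic.

```python
from typing import Dict, List, Optional, Union

NestedMapping = Dict[str, Union[List[str], Dict[str, List[str]]]]


def fused_target(module_path: str, packed_modules_mapping: NestedMapping) -> Optional[str]:
    """Hypothetical lookup: 'model.layers.0.self_attn.q_proj' -> 'qkv_proj'."""
    top = module_path.split(".", 1)[0]          # e.g. "model", "visual", "vision_model"
    # Nested case: pick the sub-mapping for this tower; flat case: use the mapping itself.
    scoped = packed_modules_mapping.get(top, packed_modules_mapping)
    if not isinstance(scoped, dict):
        return None
    leaf = module_path.rsplit(".", 1)[-1]       # e.g. "q_proj"
    for fused_name, shard_names in scoped.items():
        if isinstance(shard_names, list) and leaf in shard_names:
            return fused_name
    return None


# With the mapping from the hunk above:
#   fused_target("model.layers.0.self_attn.q_proj", mapping)  -> "qkv_proj"
#   fused_target("visual.blocks.0.attn.qkv", mapping)         -> "qkv_proj"
#   fused_target("vision_model.encoder.0.out_proj", mapping)  -> "proj"
```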
sglang/srt/models/deepseek_v2.py CHANGED
@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
  self,
  x,
  forward_batch=None,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ):
  if (self.tp_size == 1) and x.shape[0] == 0:
@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
  x, _ = self.down_proj(
- x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+ x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
  )
  return x

@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
  self,
  hidden_states: torch.Tensor,
  forward_batch: Optional[ForwardBatch] = None,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  if not self._enable_deepep_moe:
@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
  and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
  ):
  return self.forward_normal_dual_stream(
- hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
  )
  else:
  return self.forward_normal(
- hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
  )
  else:
  return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
  def forward_normal_dual_stream(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:

@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+ if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_normal(
  self,
  hidden_states: torch.Tensor,
- can_fuse_mlp_allreduce: bool = False,
+ should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
  ):
- return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+ return self.forward_cpu(hidden_states, should_allreduce_fusion)

  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+ if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_cpu(
- self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ self,
+ hidden_states: torch.Tensor,
+ should_allreduce_fusion: bool = False,
  ) -> torch.Tensor:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
  None, # a2_scale
  True, # is_vnni
  )
- if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not should_allreduce_fusion:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

@@ -1194,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
  output, _ = self.o_proj(attn_output)
  return output

+ def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+ """
+ Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+ """
+ return (
+ self.current_attention_backend == "trtllm_mla"
+ and forward_batch.forward_mode.is_decode_or_idle()
+ and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+ )
+
  def forward_absorb_prepare(
  self,
  positions: torch.Tensor,
@@ -1273,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

  q_nope_out = q_nope_out.transpose(0, 1)
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ if not self._fuse_rope_for_trtllm_mla(forward_batch):
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

  return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

@@ -1286,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
  or self.current_attention_backend == "cutlass_mla"
  or self.current_attention_backend == "trtllm_mla"
  ):
+ extra_args = {}
+ if self._fuse_rope_for_trtllm_mla(forward_batch):
+ extra_args = {
+ "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+ "is_neox": self.rotary_emb.is_neox_style,
+ }
  attn_output = self.attn_mqa(
- q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
+ q_nope_out,
+ k_nope,
+ k_nope,
+ forward_batch,
+ q_rope=q_pe,
+ k_rope=k_pe,
+ **extra_args,
  )
  else:
  q = torch.cat([q_nope_out, q_pe], dim=-1)
@@ -1842,6 +1868,8 @@ class DeepseekV2DecoderLayer(nn.Module):
  allow_reduce_scatter=True,
  )

+ self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
  def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
  return is_nextn or (
  self.config.n_routed_experts is not None
@@ -1850,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
  )

  def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
- """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
-
- if (
- self.layer_id == self.config.num_hidden_layers - 1
- or get_tensor_model_parallel_world_size() <= 1
- ):
- return False
-
- if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
- return False
+ """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""

- if not _is_sm100_supported or not _is_flashinfer_available:
- return False
+ batch_size = (
+ forward_batch.input_ids.shape[0]
+ if hasattr(forward_batch, "input_ids")
+ else 0
+ )

- if hasattr(forward_batch, "input_ids") and (
- forward_batch.input_ids.shape[0] == 0
- or forward_batch.input_ids.shape[0] > 128
- ):
+ if batch_size > 128:
  return False

- return True
+ return self._fuse_allreduce_lookup_table.get(batch_size, False)

  def forward(
  self,
@@ -1896,7 +1915,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )

- can_fuse_mlp_allreduce = (
+ should_allreduce_fusion = (
  self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
  and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
  and not self.is_nextn
@@ -1907,13 +1926,13 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch
  )
  hidden_states = self.mlp(
- hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+ hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
  )

- if can_fuse_mlp_allreduce:
+ if should_allreduce_fusion:
  hidden_states._sglang_needs_allreduce_fusion = True

- if not can_fuse_mlp_allreduce:
+ if not should_allreduce_fusion:
  hidden_states, residual = self.layer_communicator.postprocess_layer(
  hidden_states, residual, forward_batch
  )
@@ -1990,6 +2009,26 @@ class DeepseekV2DecoderLayer(nn.Module):
  )
  return output

+ def _build_fuse_allreduce_lookup_table(self):
+ static_conditions_met = (
+ self.layer_id != self.config.num_hidden_layers - 1
+ and get_tensor_model_parallel_world_size() > 1
+ and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+ and _is_sm100_supported
+ and _is_flashinfer_available
+ )
+
+ if not static_conditions_met:
+ return {}
+
+ lookup_table = {}
+ for batch_size in range(129): # 0 to 128
+ is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+ should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+ lookup_table[batch_size] = should_fuse
+
+ return lookup_table
+

  class DeepseekV2Model(nn.Module):
  fall_back_to_pt_during_load = False
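Taken together, the decoder-layer hunks move the static eligibility checks (layer position, TP size, the flashinfer allreduce-fusion flag, SM100 support, flashinfer availability) into a table built once in `__init__`, so the per-forward decision becomes a dictionary lookup keyed by batch size. When all static conditions hold, the table built by the loop above is equivalent to the following; this is a restatement for clarity, not new behavior.

```python
# Equivalent contents of _fuse_allreduce_lookup_table when static_conditions_met is True:
lookup_table = {0: False, **{bs: True for bs in range(1, 129)}}

# _should_fuse_mlp_allreduce_with_next_layer then reduces to:
#   batch_size > 128  -> False (guarded before the lookup)
#   otherwise         -> lookup_table.get(batch_size, False)
assert lookup_table[1] and lookup_table[128] and not lookup_table[0]
```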
sglang/srt/models/gemma2.py CHANGED
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):

  return result

- def get_hidden_dim(self, module_name):
- # return input_dim, output_dim
- if module_name in ["q_proj", "qkv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_attention_heads,
- )
- elif module_name in ["o_proj"]:
- return (
- self.config.head_dim * self.config.num_attention_heads,
- self.config.hidden_size,
- )
- elif module_name in ["kv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_key_value_heads,
- )
- elif module_name == "gate_up_proj":
- return self.config.hidden_size, self.config.intermediate_size
- elif module_name == "down_proj":
- return self.config.intermediate_size, self.config.hidden_size
- else:
- raise NotImplementedError()
-
- def get_module_name(self, name):
- params_mapping = {
- "q_proj": "qkv_proj",
- "k_proj": "qkv_proj",
- "v_proj": "qkv_proj",
- "gate_proj": "gate_up_proj",
- "up_proj": "gate_up_proj",
- }
- return params_mapping.get(name, name)
-
  def get_attention_sliding_window_size(self):
  return get_attention_sliding_window_size(self.config)

sglang/srt/models/gemma3n_mm.py CHANGED
@@ -501,27 +501,26 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):

  def get_hidden_dim(self, module_name):
  # return input_dim, output_dim
- if module_name in ["q_proj", "qkv_proj"]:
+ if module_name == "qkv_proj":
  return (
  self.config.hidden_size,
- self.config.head_dim * self.config.num_attention_heads,
+ self.config.head_dim
+ * (
+ self.config.num_attention_heads
+ + self.config.num_key_value_heads * 2
+ ),
  )
- elif module_name in ["o_proj"]:
+ elif module_name == "o_proj":
  return (
  self.config.head_dim * self.config.num_attention_heads,
  self.config.hidden_size,
  )
- elif module_name in ["kv_proj"]:
- return (
- self.config.hidden_size,
- self.config.head_dim * self.config.num_key_value_heads,
- )
  elif module_name == "gate_up_proj":
  assert len(set(self.config.intermediate_size)) == 1, (
  "Currently SGLang requires uniform intermediate size for all layers. "
  "Please file an issue if you need support for non-uniform intermediate sizes."
  )
- return self.config.hidden_size, self.config.intermediate_size[0]
+ return self.config.hidden_size, self.config.intermediate_size[0] * 2
  elif module_name == "down_proj":
  assert len(set(self.config.intermediate_size)) == 1, (
  "Currently SGLang requires uniform intermediate size for all layers. "
sglang/srt/models/glm4.py CHANGED
@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):

  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.embed_tokens
+
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
  @torch.no_grad()
  def forward(
  self,
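These two small accessors give multimodal wrappers (such as the GLM-4V models added in this release) a uniform way to reach the text embedding table and the parameter dtype. A hedged usage sketch; the caller shown here is hypothetical and not part of the diff.

```python
import torch


def embed_text_tokens(language_model, input_ids: torch.Tensor) -> torch.Tensor:
    # get_input_embeddings() returns the embedding module exposed above.
    embedding = language_model.get_input_embeddings()
    # dtype() reports the parameter dtype, handy for casting projected vision
    # features before merging them with the text embeddings.
    return embedding(input_ids).to(language_model.dtype())
```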