sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
  # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
  """Inference-only DeepseekV2 model."""

+ import concurrent.futures
  import logging
  import os
  from enum import IntEnum, auto
@@ -57,7 +58,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import (
@@ -126,6 +127,10 @@ if _is_cuda:
  )
  elif _is_cpu and _is_cpu_amx_available:
  pass
+ elif _is_hip:
+ from sglang.srt.layers.quantization.awq_triton import (
+ awq_dequantize_triton as awq_dequantize,
+ )
  else:
  from vllm._custom_ops import awq_dequantize

@@ -224,7 +229,7 @@ class MoEGate(nn.Module):
  )
  if config.topk_method == "noaux_tc":
  self.e_score_correction_bias = nn.Parameter(
- torch.empty((config.n_routed_experts))
+ torch.empty((config.n_routed_experts), dtype=torch.float32)
  )
  else:
  self.e_score_correction_bias = None
@@ -249,9 +254,8 @@ class MoEGate(nn.Module):
  and self.weight.shape[0] == 256
  and _device_sm >= 90
  ):
- logits = dsv3_router_gemm(hidden_states, self.weight).to(
- hidden_states.dtype
- )
+ # router gemm output float32
+ logits = dsv3_router_gemm(hidden_states, self.weight)
  else:
  logits = F.linear(hidden_states, self.weight, None)

@@ -298,6 +302,17 @@ class DeepseekV2MoE(nn.Module):
  config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
  )

+ self.topk = TopK(
+ top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+ renormalize=config.norm_topk_prob,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ topk_group=config.topk_group,
+ correction_bias=self.gate.e_score_correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
+ )
+
  self.experts = get_moe_impl_class()(
  num_experts=config.n_routed_experts
  + self.num_fused_shared_experts
@@ -306,13 +321,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_size=config.hidden_size,
  intermediate_size=config.moe_intermediate_size,
  layer_id=self.layer_id,
- renormalize=config.norm_topk_prob,
  quant_config=quant_config,
- use_grouped_topk=True,
- num_expert_group=config.n_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- topk_group=config.topk_group,
- correction_bias=self.gate.e_score_correction_bias,
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
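
The two hunks above move the top-k routing configuration out of the fused-MoE layer constructor and into a dedicated TopK module built next to the gate; the expert layer now consumes the routing result (topk_output) instead of raw router logits. As a rough illustration of what such a routing module computes, here is a minimal stand-in; it omits grouped top-k, correction bias, and fused shared experts, and its return signature is simplified relative to sglang's actual TopK:

import torch


class SimpleTopK(torch.nn.Module):
    """Illustrative stand-in for a top-k routing module: given per-token router
    logits, pick the k highest-scoring experts and optionally renormalize
    their weights so they sum to one per token."""

    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_idx = torch.topk(probs, self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # The expert layer would consume (topk_weights, topk_idx) instead of logits.
        return topk_weights, topk_idx
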
@@ -354,6 +363,7 @@ class DeepseekV2MoE(nn.Module):
  self.shared_experts.gate_up_proj.quant_method, "quant_config"
  ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
  "awq",
+ "awq_marlin",
  "moe_wna16",
  }
  self.shared_experts_is_int8 = (
@@ -437,21 +447,22 @@ class DeepseekV2MoE(nn.Module):
  def forward_normal_dual_stream(
  self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
  ) -> torch.Tensor:
- # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)

  current_stream = torch.cuda.current_stream()
  self.alt_stream.wait_stream(current_stream)
  shared_output = self._forward_shared_experts(hidden_states)

  with torch.cuda.stream(self.alt_stream):
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )
  if not _is_cuda:
  final_hidden_states *= self.routed_scaling_factor
  current_stream.wait_stream(self.alt_stream)
- final_hidden_states = final_hidden_states + shared_output
+ final_hidden_states += shared_output
  if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states
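
For context, forward_normal_dual_stream now also runs the gate and top-k computation on the alternate stream, so the shared-expert MLP (default stream) overlaps with routing plus routed experts (side stream). A minimal sketch of that stream-handshake pattern in plain PyTorch; the module arguments are placeholders, not sglang objects:

import torch

def overlap_shared_and_routed(hidden_states, shared_mlp, gate, experts, alt_stream):
    """Toy version of the wait_stream handshake used above (CUDA only)."""
    current_stream = torch.cuda.current_stream()
    # The side stream must not read hidden_states before the default stream produced it.
    alt_stream.wait_stream(current_stream)

    shared_output = shared_mlp(hidden_states)  # runs on the default stream

    with torch.cuda.stream(alt_stream):  # routing + routed experts run concurrently
        router_logits = gate(hidden_states)
        routed_output = experts(hidden_states, router_logits)

    # The default stream must wait for the side stream before combining results.
    current_stream.wait_stream(alt_stream)
    return routed_output + shared_output
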
@@ -462,13 +473,14 @@ class DeepseekV2MoE(nn.Module):
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
  ):
- return self.forward_cpu(hidden_states)
+ return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)

  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )
  if not _is_cuda and not _use_aiter:
  # fused in biased_grouped_topk so we can skip here
@@ -479,11 +491,14 @@ class DeepseekV2MoE(nn.Module):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

- def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ def forward_cpu(
+ self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ ) -> torch.Tensor:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  fused_experts_out = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )

  assert use_intel_amx_backend(
@@ -528,30 +543,21 @@ class DeepseekV2MoE(nn.Module):
  None, # a2_scale
  True, # is_vnni
  )
- if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

  def forward_deepep(
  self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
  ) -> torch.Tensor:
- forward_mode = forward_batch.forward_mode
  shared_output = None
- if is_non_idle_and_non_empty(forward_mode, hidden_states):
+ if hidden_states.shape[0] > 0:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
  shared_output = self._forward_shared_experts(hidden_states)
- topk_weights, topk_idx = select_experts(
- hidden_states=hidden_states,
- router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=True,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- routed_scaling_factor=self.routed_scaling_factor,
+ topk_weights, topk_idx, _ = self.topk(
+ hidden_states,
+ router_logits,
  num_token_non_padded=forward_batch.num_token_non_padded,
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
  layer_id=self.layer_id,
@@ -641,17 +647,9 @@ class DeepseekV2MoE(nn.Module):
  with get_global_expert_distribution_recorder().with_current_layer(
  self.layer_id
  ):
- state.topk_weights_local, state.topk_idx_local = select_experts(
+ state.topk_weights_local, state.topk_idx_local, _ = self.topk(
  hidden_states=hidden_states,
  router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=True,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- routed_scaling_factor=self.routed_scaling_factor,
  num_token_non_padded=state.forward_batch.num_token_non_padded,
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
  layer_id=self.layer_id,
@@ -926,7 +924,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  has_fused_proj
  and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
  and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
- in {"awq", "moe_wna16"}
+ in {"awq", "awq_marlin", "moe_wna16"}
  )
  self.use_min_latency_fused_a_gemm = (
  has_fused_proj
@@ -1151,7 +1149,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
+ kv_a = self.kv_a_layernorm(kv_a)
  kv = self.kv_b_proj(kv_a)[0]
  kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
  k_nope = kv[..., : self.qk_nope_head_dim]
@@ -1690,7 +1688,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
+ kv_a = self.kv_a_layernorm(kv_a)
  kv = self.kv_b_proj(kv_a)[0]
  kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
  k_nope = kv[..., : self.qk_nope_head_dim]
@@ -2172,7 +2170,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
  if hasattr(self_attn.kv_b_proj, "qweight"):
  # AWQ compatible
- if _is_cuda:
+ if _is_cuda or _is_hip:
  w = awq_dequantize(
  self_attn.kv_b_proj.qweight,
  self_attn.kv_b_proj.scales,
@@ -2434,154 +2432,175 @@ class DeepseekV2ForCausalLM(nn.Module):
  assert self.num_fused_shared_experts == 1
  log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

- params_dict = dict(self.named_parameters())
- weight_names = []
- for name, loaded_weight in weights:
- if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
- name = name.replace(
- "mlp.shared_experts",
- f"mlp.experts.{self.config.n_routed_experts}",
- )
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ futures = []
+ params_dict = dict(self.named_parameters())
+ weight_names = []
+ for name, loaded_weight in weights:
+ if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+ name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts}",
+ )

- weight_names.append(name)
+ weight_names.append(name)

- if not is_nextn:
- if hasattr(self.config, "num_nextn_predict_layers"):
- num_nextn_layers = self.config.num_nextn_predict_layers
- if num_nextn_layers > 0 and name.startswith("model.layers"):
- name_list = name.split(".")
- if (
- len(name_list) >= 3
- and int(name_list[2]) >= self.config.num_hidden_layers
- ):
- continue
- else:
- if not name.startswith(nextn_layer_prefix):
- continue
+ if not is_nextn:
+ if hasattr(self.config, "num_nextn_predict_layers"):
+ num_nextn_layers = self.config.num_nextn_predict_layers
+ if num_nextn_layers > 0 and name.startswith("model.layers"):
+ name_list = name.split(".")
+ if (
+ len(name_list) >= 3
+ and int(name_list[2]) >= self.config.num_hidden_layers
+ ):
+ continue
+ else:
+ if not name.startswith(nextn_layer_prefix):
+ continue

- # Use shared head and embed weights from target model
- if "shared_head.head" in name or "embed_tokens" in name:
- continue
+ # Use shared head and embed weights from target model
+ if "shared_head.head" in name or "embed_tokens" in name:
+ continue

- is_decoder = True
- # For nextn specific weights
- for weight_name in nextn_spec_weight_names:
- if weight_name in name:
- name = name.replace(nextn_layer_prefix, "model")
- is_decoder = False
- break
- # For decoder layer weights
- if is_decoder:
- name = name.replace(nextn_layer_prefix, "model.decoder")
-
- if "rotary_emb.inv_freq" in name:
- continue
- for param_name, weight_name, shard_id in stacked_params_mapping:
- # Skip non-stacked layers and experts (experts handled below).
- if weight_name not in name:
- continue
- # We have mlp.experts[0].gate_proj in the checkpoint.
- # Since we handle the experts below in expert_params_mapping,
- # we need to skip here BEFORE we update the name, otherwise
- # name will be updated to mlp.experts[0].gate_up_proj, which
- # will then be updated below in expert_params_mapping
- # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if ("mlp.experts." in name) and name not in params_dict:
- continue
- name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ is_decoder = True
+ # For nextn specific weights
+ for weight_name in nextn_spec_weight_names:
+ if weight_name in name:
+ name = name.replace(nextn_layer_prefix, "model")
+ is_decoder = False
+ break
+ # For decoder layer weights
+ if is_decoder:
+ name = name.replace(nextn_layer_prefix, "model.decoder")
+
+ if "rotary_emb.inv_freq" in name:
  continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
  if weight_name not in name:
  continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
  name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
  param = params_dict[name]
  weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
+ futures.append(
+ executor.submit(weight_loader, param, loaded_weight, shard_id)
  )
  break
  else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- if fuse_qkv_a_proj and (
- "q_a_proj" in name or "kv_a_proj_with_mqa" in name
- ):
- cached_a_proj[name] = loaded_weight
- q_a_proj_name = (
- name
- if "q_a_proj" in name
- else name.replace("kv_a_proj_with_mqa", "q_a_proj")
- )
- kv_a_proj_name = (
- name
- if "kv_a_proj_with_mqa" in name
- else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ futures.append(
+ executor.submit(
+ weight_loader,
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
  )
-
- # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
- if (
- q_a_proj_name in cached_a_proj
- and kv_a_proj_name in cached_a_proj
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if fuse_qkv_a_proj and (
+ "q_a_proj" in name or "kv_a_proj_with_mqa" in name
  ):
- q_a_proj_weight = cached_a_proj[q_a_proj_name]
- kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
- cat_dim = 0
- if self.quant_config is not None and (
- self.quant_config.get_name() == "awq"
- or self.quant_config.get_name() == "moe_wna16"
- ):
- cat_dim = 1
- fused_weight = torch.cat(
- [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
- )
- param_name = (
- name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+ cached_a_proj[name] = loaded_weight
+ q_a_proj_name = (
+ name
  if "q_a_proj" in name
- else name.replace(
- "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
- )
+ else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+ )
+ kv_a_proj_name = (
+ name
+ if "kv_a_proj_with_mqa" in name
+ else name.replace("q_a_proj", "kv_a_proj_with_mqa")
  )
- param = params_dict[param_name]

+ # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+ if (
+ q_a_proj_name in cached_a_proj
+ and kv_a_proj_name in cached_a_proj
+ ):
+ q_a_proj_weight = cached_a_proj[q_a_proj_name]
+ kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+ cat_dim = 0
+ if self.quant_config is not None and (
+ self.quant_config.get_name() == "awq"
+ or self.quant_config.get_name() == "awq_marlin"
+ or self.quant_config.get_name() == "moe_wna16"
+ ):
+ cat_dim = 1
+ fused_weight = torch.cat(
+ [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+ )
+ param_name = (
+ name.replace(
+ "q_a_proj", "fused_qkv_a_proj_with_mqa"
+ )
+ if "q_a_proj" in name
+ else name.replace(
+ "kv_a_proj_with_mqa",
+ "fused_qkv_a_proj_with_mqa",
+ )
+ )
+ param = params_dict[param_name]
+
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ futures.append(
+ executor.submit(weight_loader, param, fused_weight)
+ )
+ cached_a_proj.pop(q_a_proj_name)
+ cached_a_proj.pop(kv_a_proj_name)
+ else:
+ if (
+ "k_scale" in name or "v_scale" in name
+ ) and name not in params_dict:
+ # modelopt attn kv scale is named differently
+ for scale in ["k_scale", "v_scale"]:
+ if scale in name:
+ name = name.replace(
+ f"{scale[0]}_proj", "attn_mqa"
+ )
+ break
+ if name not in params_dict:
+ # modelopt ckpt contains not needed weights for MTP module:
+ # model.decoder.self_attn.attn_mqa.v_scale and
+ # model.decoder.self_attn.attn_mqa.k_scale
+ logger.warning(f"{name} not found in params_dict.")
+ continue
+ param = params_dict[name]
  weight_loader = getattr(
  param, "weight_loader", default_weight_loader
  )
- weight_loader(param, fused_weight)
- cached_a_proj.pop(q_a_proj_name)
- cached_a_proj.pop(kv_a_proj_name)
- else:
- if (
- "k_scale" in name or "v_scale" in name
- ) and name not in params_dict:
- # modelopt attn kv scale is named differently
- for scale in ["k_scale", "v_scale"]:
- if scale in name:
- name = name.replace(f"{scale[0]}_proj", "attn_mqa")
- break
- if name not in params_dict:
- # modelopt ckpt contains not needed weights for MTP module:
- # model.decoder.self_attn.attn_mqa.v_scale and
- # model.decoder.self_attn.attn_mqa.k_scale
- logger.warning(f"{name} not found in params_dict.")
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
+ futures.append(
+ executor.submit(weight_loader, param, loaded_weight)
+ )
+
+ # Wait for all tasks to complete and raise any exceptions.
+ for future in concurrent.futures.as_completed(futures):
+ future.result()

  self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
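
The load_weights rewrite above wraps every per-parameter weight_loader call in a ThreadPoolExecutor and defers error handling to the end: each call is submitted as a future, and future.result() re-raises any exception from the worker threads on the main thread. A self-contained sketch of that submit/as_completed pattern (the loader and dict names are illustrative, not sglang's API):

import concurrent.futures

def load_weights_parallel(weights, params_dict):
    """Submit each weight copy to a thread pool, then surface any worker errors."""
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for name, loaded_weight in weights:
            param = params_dict[name]
            loader = getattr(param, "weight_loader", None)
            if loader is None:
                continue
            futures.append(executor.submit(loader, param, loaded_weight))
        # Wait for all tasks to complete; .result() re-raises any exception.
        for future in concurrent.futures.as_completed(futures):
            future.result()
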
@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
  def get_image_feature(self, items: List[MultimodalDataItem]):

  images_spatial_crop = torch.cat(
- [item.image_spatial_crop for item in items], dim=0
+ [item.images_spatial_crop for item in items], dim=0
  )

  assert images_spatial_crop.dim() == 3
@@ -268,9 +268,9 @@
  # TODO: can it be batched ?
  images_in_this_batch = []
  for item in items:
- assert item.pixel_values.dim() == 4
+ assert item.feature.dim() == 4
  image_feature = self.vision.forward_features(
- item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+ item.feature.type(next(self.vision.parameters()).dtype).to(
  device=next(self.vision.parameters()).device
  )
  )
@@ -278,8 +278,8 @@
  _, hw, n_dim = images_embeds.shape
  h = w = int(hw**0.5)
  tile_index = 0
- for jdx in range(item.image_spatial_crop.shape[1]):
- num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+ for jdx in range(item.images_spatial_crop.shape[1]):
+ num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
  if num_width_tiles == 0 or num_height_tiles == 0:
  break
  num_tiles_in_image = num_width_tiles * num_height_tiles
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
  input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

+ @torch.no_grad()
+ def forward_split_prefill(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ split_interval: Tuple[int, int], # [start, end) 0-based
+ input_embeds: torch.Tensor = None,
+ ):
+ start, end = split_interval
+ # embed
+ if start == 0:
+ if input_embeds is None:
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+ else:
+ forward_batch.hidden_states = input_embeds
+
+ # Normalize the embedding by sqrt(hidden_size)
+ forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+ # decoder layer
+ for i in range(start, end):
+ layer = self.model.layers[i]
+ forward_batch.hidden_states, forward_batch.residual = layer(
+ positions,
+ forward_batch.hidden_states,
+ forward_batch,
+ forward_batch.residual,
+ )
+
+ if end == self.model.config.num_hidden_layers:
+ # norm
+ forward_batch.hidden_states, _ = self.model.norm(
+ forward_batch.hidden_states, forward_batch.residual
+ )
+
+ # logits process
+ result = self.logits_processor(
+ input_ids,
+ forward_batch.hidden_states,
+ self.model.embed_tokens,
+ forward_batch,
+ )
+ else:
+ result = None
+
+ return result
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
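
The new forward_split_prefill in GemmaForCausalLM runs only the decoder layers in [start, end), keeps intermediate activations on forward_batch, and computes the final norm and logits only on the chunk that reaches the last layer. A hypothetical driver loop showing how a caller might walk the layer stack in chunks; chunk_size and the surrounding names are assumptions for illustration, not sglang's API:

def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    """Drive forward_split_prefill over the whole layer stack, chunk by chunk."""
    num_layers = model.model.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        # Intermediate chunks return None; the final chunk returns the logits output.
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result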