sglang-0.4.9.post2-py3-none-any.whl → sglang-0.4.9.post3-py3-none-any.whl

This diff shows the changes between these two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -16,6 +16,7 @@
  # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
  """Inference-only DeepseekV2 model."""

+ import concurrent.futures
  import logging
  import os
  from enum import IntEnum, auto
@@ -57,7 +58,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import (
@@ -126,6 +127,10 @@ if _is_cuda:
  )
  elif _is_cpu and _is_cpu_amx_available:
  pass
+ elif _is_hip:
+ from sglang.srt.layers.quantization.awq_triton import (
+ awq_dequantize_triton as awq_dequantize,
+ )
  else:
  from vllm._custom_ops import awq_dequantize

@@ -224,7 +229,7 @@ class MoEGate(nn.Module):
  )
  if config.topk_method == "noaux_tc":
  self.e_score_correction_bias = nn.Parameter(
- torch.empty((config.n_routed_experts))
+ torch.empty((config.n_routed_experts), dtype=torch.float32)
  )
  else:
  self.e_score_correction_bias = None
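Note: taken together with the next hunk, which stops casting the dsv3_router_gemm output back to the activation dtype, the gate now keeps its routing scores in float32, so the correction bias above is applied at full precision. A hedged sketch of the noaux_tc-style scoring this affects (illustrative code, not the exact sglang kernel):

import torch

def noaux_tc_scores(router_logits_fp32: torch.Tensor, correction_bias: torch.Tensor) -> torch.Tensor:
    # Illustrative only: sigmoid of the float32 router logits plus the float32
    # correction bias gives the scores used for grouped top-k expert selection.
    scores = torch.sigmoid(router_logits_fp32)
    return scores + correction_bias
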
@@ -249,9 +254,8 @@ class MoEGate(nn.Module):
  and self.weight.shape[0] == 256
  and _device_sm >= 90
  ):
- logits = dsv3_router_gemm(hidden_states, self.weight).to(
- hidden_states.dtype
- )
+ # router gemm output float32
+ logits = dsv3_router_gemm(hidden_states, self.weight)
  else:
  logits = F.linear(hidden_states, self.weight, None)

@@ -298,6 +302,17 @@ class DeepseekV2MoE(nn.Module):
  config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
  )

+ self.topk = TopK(
+ top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+ renormalize=config.norm_topk_prob,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ topk_group=config.topk_group,
+ correction_bias=self.gate.e_score_correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
+ )
+
  self.experts = get_moe_impl_class()(
  num_experts=config.n_routed_experts
  + self.num_fused_shared_experts
@@ -306,13 +321,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_size=config.hidden_size,
  intermediate_size=config.moe_intermediate_size,
  layer_id=self.layer_id,
- renormalize=config.norm_topk_prob,
  quant_config=quant_config,
- use_grouped_topk=True,
- num_expert_group=config.n_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- topk_group=config.topk_group,
- correction_bias=self.gate.e_score_correction_bias,
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
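Note: with the two hunks above, the grouped top-k routing arguments move out of the experts constructor and into the standalone TopK module built next to the gate. The forward paths later in this diff then follow the call pattern sketched below (a minimal sketch assuming a module that defines self.gate, self.topk, and self.experts as shown above):

def route_and_dispatch(self, hidden_states):
    # router_logits: (num_tokens, n_experts)
    router_logits = self.gate(hidden_states)
    # TopK owns renormalization, expert grouping, and the correction bias,
    # and returns the precomputed top-k bundle that the experts consume.
    topk_output = self.topk(hidden_states, router_logits)
    return self.experts(hidden_states=hidden_states, topk_output=topk_output)
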
@@ -354,6 +363,7 @@ class DeepseekV2MoE(nn.Module):
  self.shared_experts.gate_up_proj.quant_method, "quant_config"
  ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
  "awq",
+ "awq_marlin",
  "moe_wna16",
  }
  self.shared_experts_is_int8 = (
@@ -437,21 +447,22 @@ class DeepseekV2MoE(nn.Module):
  def forward_normal_dual_stream(
  self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
  ) -> torch.Tensor:
- # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)

  current_stream = torch.cuda.current_stream()
  self.alt_stream.wait_stream(current_stream)
  shared_output = self._forward_shared_experts(hidden_states)

  with torch.cuda.stream(self.alt_stream):
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )
  if not _is_cuda:
  final_hidden_states *= self.routed_scaling_factor
  current_stream.wait_stream(self.alt_stream)
- final_hidden_states = final_hidden_states + shared_output
+ final_hidden_states += shared_output
  if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states
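Note: the rewritten forward_normal_dual_stream above now runs the gate, top-k selection, and routed experts on the alternate CUDA stream while the shared experts run on the current stream. A simplified sketch of that stream-overlap pattern (generic PyTorch, not the exact sglang code):

import torch

def overlap_on_alt_stream(hidden_states, run_shared, run_routed, alt_stream):
    # Make the alternate stream wait for work already queued on the current
    # stream so it sees up-to-date inputs.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)
    shared_output = run_shared(hidden_states)   # runs on the current stream
    with torch.cuda.stream(alt_stream):         # runs concurrently on alt_stream
        routed_output = run_routed(hidden_states)
    # Rejoin the streams before combining the two results.
    current_stream.wait_stream(alt_stream)
    return routed_output + shared_output
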
@@ -462,13 +473,14 @@ class DeepseekV2MoE(nn.Module):
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
  ):
- return self.forward_cpu(hidden_states)
+ return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)

  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )
  if not _is_cuda and not _use_aiter:
  # fused in biased_grouped_topk so we can skip here
@@ -479,11 +491,14 @@ class DeepseekV2MoE(nn.Module):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

- def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ def forward_cpu(
+ self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ ) -> torch.Tensor:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  fused_experts_out = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states, topk_output=topk_output
  )

  assert use_intel_amx_backend(
@@ -528,7 +543,7 @@ class DeepseekV2MoE(nn.Module):
  None, # a2_scale
  True, # is_vnni
  )
- if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

@@ -541,17 +556,9 @@ class DeepseekV2MoE(nn.Module):
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
  shared_output = self._forward_shared_experts(hidden_states)
- topk_weights, topk_idx = select_experts(
- hidden_states=hidden_states,
- router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=True,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- routed_scaling_factor=self.routed_scaling_factor,
+ topk_weights, topk_idx, _ = self.topk(
+ hidden_states,
+ router_logits,
  num_token_non_padded=forward_batch.num_token_non_padded,
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
  layer_id=self.layer_id,
@@ -641,17 +648,9 @@ class DeepseekV2MoE(nn.Module):
  with get_global_expert_distribution_recorder().with_current_layer(
  self.layer_id
  ):
- state.topk_weights_local, state.topk_idx_local = select_experts(
+ state.topk_weights_local, state.topk_idx_local, _ = self.topk(
  hidden_states=hidden_states,
  router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=True,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- routed_scaling_factor=self.routed_scaling_factor,
  num_token_non_padded=state.forward_batch.num_token_non_padded,
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
  layer_id=self.layer_id,
@@ -926,7 +925,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  has_fused_proj
  and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
  and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
- in {"awq", "moe_wna16"}
+ in {"awq", "awq_marlin", "moe_wna16"}
  )
  self.use_min_latency_fused_a_gemm = (
  has_fused_proj
@@ -1151,7 +1150,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
+ kv_a = self.kv_a_layernorm(kv_a)
  kv = self.kv_b_proj(kv_a)[0]
  kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
  k_nope = kv[..., : self.qk_nope_head_dim]
@@ -1690,7 +1689,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
+ kv_a = self.kv_a_layernorm(kv_a)
  kv = self.kv_b_proj(kv_a)[0]
  kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
  k_nope = kv[..., : self.qk_nope_head_dim]
@@ -2172,7 +2171,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
  if hasattr(self_attn.kv_b_proj, "qweight"):
  # AWQ compatible
- if _is_cuda:
+ if _is_cuda or _is_hip:
  w = awq_dequantize(
  self_attn.kv_b_proj.qweight,
  self_attn.kv_b_proj.scales,
@@ -2434,154 +2433,175 @@ class DeepseekV2ForCausalLM(nn.Module):
  assert self.num_fused_shared_experts == 1
  log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

- params_dict = dict(self.named_parameters())
- weight_names = []
- for name, loaded_weight in weights:
- if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
- name = name.replace(
- "mlp.shared_experts",
- f"mlp.experts.{self.config.n_routed_experts}",
- )
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ futures = []
+ params_dict = dict(self.named_parameters())
+ weight_names = []
+ for name, loaded_weight in weights:
+ if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+ name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts}",
+ )

- weight_names.append(name)
+ weight_names.append(name)

- if not is_nextn:
- if hasattr(self.config, "num_nextn_predict_layers"):
- num_nextn_layers = self.config.num_nextn_predict_layers
- if num_nextn_layers > 0 and name.startswith("model.layers"):
- name_list = name.split(".")
- if (
- len(name_list) >= 3
- and int(name_list[2]) >= self.config.num_hidden_layers
- ):
- continue
- else:
- if not name.startswith(nextn_layer_prefix):
- continue
+ if not is_nextn:
+ if hasattr(self.config, "num_nextn_predict_layers"):
+ num_nextn_layers = self.config.num_nextn_predict_layers
+ if num_nextn_layers > 0 and name.startswith("model.layers"):
+ name_list = name.split(".")
+ if (
+ len(name_list) >= 3
+ and int(name_list[2]) >= self.config.num_hidden_layers
+ ):
+ continue
+ else:
+ if not name.startswith(nextn_layer_prefix):
+ continue

- # Use shared head and embed weights from target model
- if "shared_head.head" in name or "embed_tokens" in name:
- continue
+ # Use shared head and embed weights from target model
+ if "shared_head.head" in name or "embed_tokens" in name:
+ continue

- is_decoder = True
- # For nextn specific weights
- for weight_name in nextn_spec_weight_names:
- if weight_name in name:
- name = name.replace(nextn_layer_prefix, "model")
- is_decoder = False
- break
- # For decoder layer weights
- if is_decoder:
- name = name.replace(nextn_layer_prefix, "model.decoder")
-
- if "rotary_emb.inv_freq" in name:
- continue
- for param_name, weight_name, shard_id in stacked_params_mapping:
- # Skip non-stacked layers and experts (experts handled below).
- if weight_name not in name:
- continue
- # We have mlp.experts[0].gate_proj in the checkpoint.
- # Since we handle the experts below in expert_params_mapping,
- # we need to skip here BEFORE we update the name, otherwise
- # name will be updated to mlp.experts[0].gate_up_proj, which
- # will then be updated below in expert_params_mapping
- # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if ("mlp.experts." in name) and name not in params_dict:
- continue
- name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ is_decoder = True
+ # For nextn specific weights
+ for weight_name in nextn_spec_weight_names:
+ if weight_name in name:
+ name = name.replace(nextn_layer_prefix, "model")
+ is_decoder = False
+ break
+ # For decoder layer weights
+ if is_decoder:
+ name = name.replace(nextn_layer_prefix, "model.decoder")
+
+ if "rotary_emb.inv_freq" in name:
  continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
  if weight_name not in name:
  continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
  name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
  param = params_dict[name]
  weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
+ futures.append(
+ executor.submit(weight_loader, param, loaded_weight, shard_id)
  )
  break
  else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- if fuse_qkv_a_proj and (
- "q_a_proj" in name or "kv_a_proj_with_mqa" in name
- ):
- cached_a_proj[name] = loaded_weight
- q_a_proj_name = (
- name
- if "q_a_proj" in name
- else name.replace("kv_a_proj_with_mqa", "q_a_proj")
- )
- kv_a_proj_name = (
- name
- if "kv_a_proj_with_mqa" in name
- else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ futures.append(
+ executor.submit(
+ weight_loader,
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
  )
-
- # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
- if (
- q_a_proj_name in cached_a_proj
- and kv_a_proj_name in cached_a_proj
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if fuse_qkv_a_proj and (
+ "q_a_proj" in name or "kv_a_proj_with_mqa" in name
  ):
- q_a_proj_weight = cached_a_proj[q_a_proj_name]
- kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
- cat_dim = 0
- if self.quant_config is not None and (
- self.quant_config.get_name() == "awq"
- or self.quant_config.get_name() == "moe_wna16"
- ):
- cat_dim = 1
- fused_weight = torch.cat(
- [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
- )
- param_name = (
- name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+ cached_a_proj[name] = loaded_weight
+ q_a_proj_name = (
+ name
  if "q_a_proj" in name
- else name.replace(
- "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
- )
+ else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+ )
+ kv_a_proj_name = (
+ name
+ if "kv_a_proj_with_mqa" in name
+ else name.replace("q_a_proj", "kv_a_proj_with_mqa")
  )
- param = params_dict[param_name]

+ # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+ if (
+ q_a_proj_name in cached_a_proj
+ and kv_a_proj_name in cached_a_proj
+ ):
+ q_a_proj_weight = cached_a_proj[q_a_proj_name]
+ kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+ cat_dim = 0
+ if self.quant_config is not None and (
+ self.quant_config.get_name() == "awq"
+ or self.quant_config.get_name() == "awq_marlin"
+ or self.quant_config.get_name() == "moe_wna16"
+ ):
+ cat_dim = 1
+ fused_weight = torch.cat(
+ [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+ )
+ param_name = (
+ name.replace(
+ "q_a_proj", "fused_qkv_a_proj_with_mqa"
+ )
+ if "q_a_proj" in name
+ else name.replace(
+ "kv_a_proj_with_mqa",
+ "fused_qkv_a_proj_with_mqa",
+ )
+ )
+ param = params_dict[param_name]
+
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ futures.append(
+ executor.submit(weight_loader, param, fused_weight)
+ )
+ cached_a_proj.pop(q_a_proj_name)
+ cached_a_proj.pop(kv_a_proj_name)
+ else:
+ if (
+ "k_scale" in name or "v_scale" in name
+ ) and name not in params_dict:
+ # modelopt attn kv scale is named differently
+ for scale in ["k_scale", "v_scale"]:
+ if scale in name:
+ name = name.replace(
+ f"{scale[0]}_proj", "attn_mqa"
+ )
+ break
+ if name not in params_dict:
+ # modelopt ckpt contains not needed weights for MTP module:
+ # model.decoder.self_attn.attn_mqa.v_scale and
+ # model.decoder.self_attn.attn_mqa.k_scale
+ logger.warning(f"{name} not found in params_dict.")
+ continue
+ param = params_dict[name]
  weight_loader = getattr(
  param, "weight_loader", default_weight_loader
  )
- weight_loader(param, fused_weight)
- cached_a_proj.pop(q_a_proj_name)
- cached_a_proj.pop(kv_a_proj_name)
- else:
- if (
- "k_scale" in name or "v_scale" in name
- ) and name not in params_dict:
- # modelopt attn kv scale is named differently
- for scale in ["k_scale", "v_scale"]:
- if scale in name:
- name = name.replace(f"{scale[0]}_proj", "attn_mqa")
- break
- if name not in params_dict:
- # modelopt ckpt contains not needed weights for MTP module:
- # model.decoder.self_attn.attn_mqa.v_scale and
- # model.decoder.self_attn.attn_mqa.k_scale
- logger.warning(f"{name} not found in params_dict.")
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
+ futures.append(
+ executor.submit(weight_loader, param, loaded_weight)
+ )
+
+ # Wait for all tasks to complete and raise any exceptions.
+ for future in concurrent.futures.as_completed(futures):
+ future.result()

  self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

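Note: the reindented block above wraps the existing weight-loading loop in a concurrent.futures.ThreadPoolExecutor and submits each weight_loader call instead of invoking it inline, then drains the futures to surface any exception. A minimal sketch of that pattern (placeholder params_dict / weights arguments, not the full sglang loader):

import concurrent.futures

def load_weights_in_parallel(params_dict, weights):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is None:
                continue
            # Each loader call becomes a task; loaders write into disjoint
            # parameters, so they can run concurrently.
            futures.append(executor.submit(weight_loader, param, loaded_weight))
        # Wait for all tasks to complete and re-raise any exception.
        for future in concurrent.futures.as_completed(futures):
            future.result()
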
sglang/srt/models/deepseek_vl2.py

@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
  def get_image_feature(self, items: List[MultimodalDataItem]):

  images_spatial_crop = torch.cat(
- [item.image_spatial_crop for item in items], dim=0
+ [item.images_spatial_crop for item in items], dim=0
  )

  assert images_spatial_crop.dim() == 3
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
  # TODO: can it be batched ?
  images_in_this_batch = []
  for item in items:
- assert item.pixel_values.dim() == 4
+ assert item.feature.dim() == 4
  image_feature = self.vision.forward_features(
- item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+ item.feature.type(next(self.vision.parameters()).dtype).to(
  device=next(self.vision.parameters()).device
  )
  )
@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
  _, hw, n_dim = images_embeds.shape
  h = w = int(hw**0.5)
  tile_index = 0
- for jdx in range(item.image_spatial_crop.shape[1]):
- num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+ for jdx in range(item.images_spatial_crop.shape[1]):
+ num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
  if num_width_tiles == 0 or num_height_tiles == 0:
  break
  num_tiles_in_image = num_width_tiles * num_height_tiles
sglang/srt/models/gemma.py

@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
  input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

+ @torch.no_grad()
+ def forward_split_prefill(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ split_interval: Tuple[int, int], # [start, end) 0-based
+ input_embeds: torch.Tensor = None,
+ ):
+ start, end = split_interval
+ # embed
+ if start == 0:
+ if input_embeds is None:
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+ else:
+ forward_batch.hidden_states = input_embeds
+
+ # Normalize the embedding by sqrt(hidden_size)
+ forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+ # decoder layer
+ for i in range(start, end):
+ layer = self.model.layers[i]
+ forward_batch.hidden_states, forward_batch.residual = layer(
+ positions,
+ forward_batch.hidden_states,
+ forward_batch,
+ forward_batch.residual,
+ )
+
+ if end == self.model.config.num_hidden_layers:
+ # norm
+ forward_batch.hidden_states, _ = self.model.norm(
+ forward_batch.hidden_states, forward_batch.residual
+ )
+
+ # logits process
+ result = self.logits_processor(
+ input_ids,
+ forward_batch.hidden_states,
+ self.model.embed_tokens,
+ forward_batch,
+ )
+ else:
+ result = None
+
+ return result
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
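Note: the new forward_split_prefill above processes a half-open [start, end) layer interval per call, so a caller can walk the decoder in chunks. A hedged sketch of such a driver loop (the chunk size and loop are illustrative, not sglang's scheduler):

def run_split_prefill(model, input_ids, positions, forward_batch, chunk_size=8):
    # Walk the decoder layers in [start, end) chunks; the call that reaches the
    # last layer also applies the final norm and logits processor and returns
    # the logits output, while earlier calls return None.
    num_layers = model.model.config.num_hidden_layers
    result = None
    for start in range(0, num_layers, chunk_size):
        end = min(start + chunk_size, num_layers)
        result = model.forward_split_prefill(
            input_ids, positions, forward_batch, (start, end)
        )
    return result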