sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -18,7 +18,8 @@

  import logging
  import os
- from enum import IntEnum, auto
+ from dataclasses import dataclass
+ from enum import Enum, IntEnum, auto
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+ get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  parallel_state,
  tensor_model_parallel_all_reduce,
@@ -51,10 +53,15 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
+ from sglang.srt.layers.quantization.fp8_kernel import (
+ _enable_jit_deepgemm_bmm,
+ per_tensor_quant_mla_deep_gemm_masked_fp8,
+ per_tensor_quant_mla_fp8,
+ )
  from sglang.srt.layers.quantization.fp8_utils import (
  block_quant_to_tensor_quant,
  channel_quant_to_tensor_quant,
@@ -73,17 +80,16 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+ from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

  _is_hip = is_hip()
  _is_cuda = is_cuda()

  if _is_cuda:
+ from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
  from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  else:
- from vllm import _custom_ops as ops
+ from vllm._custom_ops import awq_dequantize

  if _is_hip:
  from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -96,7 +102,6 @@ logger = logging.getLogger(__name__)


  class AttnForwardMethod(IntEnum):
-
  # Use multi-head attention
  MHA = auto()

@@ -147,7 +152,7 @@ class DeepseekV2MLP(nn.Module):
  )
  self.act_fn = SiluAndMul()

- def forward(self, x):
+ def forward(self, x, forward_mode: Optional[ForwardMode] = None):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
  x, _ = self.down_proj(x)
@@ -188,11 +193,7 @@ class DeepseekV2MoE(nn.Module):
  self.tp_size = get_tensor_model_parallel_world_size()
  self.routed_scaling_factor = config.routed_scaling_factor
  self.n_shared_experts = config.n_shared_experts
- self.n_share_experts_fusion = (
- global_server_args_dict["n_share_experts_fusion"]
- if global_server_args_dict["n_share_experts_fusion"] is not None
- else 0
- )
+ self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]

  if self.tp_size > config.n_routed_experts:
  raise ValueError(
@@ -225,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
  num_expert_group=config.n_group,
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
  dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -333,6 +335,7 @@ class DeepseekV2MoE(nn.Module):
  topk_group=self.topk_group,
  num_expert_group=self.num_expert_group,
  correction_bias=self.correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
  )
  if self.ep_size > 1:
  # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -373,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
  return final_hidden_states

  def _forward_shared_experts(self, hidden_states):
- if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+ if self.n_share_experts_fusion == 0:
  return self.shared_experts(hidden_states)
  else:
  return None
@@ -387,179 +390,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  return 0.1 * mscale * math.log(scale) + 1.0


- class DeepseekV2Attention(nn.Module):
-
- def __init__(
- self,
- config: PretrainedConfig,
- hidden_size: int,
- num_heads: int,
- qk_nope_head_dim: int,
- qk_rope_head_dim: int,
- v_head_dim: int,
- q_lora_rank: int,
- kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: Optional[Dict[str, Any]] = None,
- max_position_embeddings: int = 8192,
- quant_config: Optional[QuantizationConfig] = None,
- layer_id=None,
- reduce_results: bool = True,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.layer_id = layer_id
- self.hidden_size = hidden_size
- self.qk_nope_head_dim = qk_nope_head_dim
- self.qk_rope_head_dim = qk_rope_head_dim
- self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
- self.v_head_dim = v_head_dim
- self.q_lora_rank = q_lora_rank
- self.kv_lora_rank = kv_lora_rank
-
- self.dp_size = get_attention_dp_size()
- attn_tp_rank = get_attention_tp_rank()
- attn_tp_size = get_attention_tp_size()
-
- self.num_heads = num_heads
- assert num_heads % attn_tp_size == 0
- self.num_local_heads = num_heads // attn_tp_size
- self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
- self.max_position_embeddings = max_position_embeddings
-
- if self.q_lora_rank is not None:
- self.q_a_proj = ReplicatedLinear(
- self.hidden_size,
- self.q_lora_rank,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_a_proj", prefix),
- )
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ColumnParallelLinear(
- q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_b_proj", prefix),
- )
- else:
- self.q_proj = ColumnParallelLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_proj", prefix),
- tp_rank=attn_tp_rank,
- tp_size=attn_tp_size,
- )
-
- self.kv_a_proj_with_mqa = ReplicatedLinear(
- self.hidden_size,
- self.kv_lora_rank + self.qk_rope_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_a_proj_with_mqa", prefix),
- )
- self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
- self.kv_b_proj = ColumnParallelLinear(
- self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_b_proj", prefix),
- )
- # O projection.
- self.o_proj = RowParallelLinear(
- self.num_heads * self.v_head_dim,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("o_proj", prefix),
- reduce_results=reduce_results,
- tp_rank=attn_tp_rank,
- tp_size=attn_tp_size,
- )
- rope_scaling["rope_type"] = "deepseek_yarn"
- self.rotary_emb = get_rope_wrapper(
- qk_rope_head_dim,
- rotary_dim=qk_rope_head_dim,
- max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
- is_neox_style=False,
- device=global_server_args_dict["device"],
- )
-
- if rope_scaling:
- mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
- scaling_factor = rope_scaling["factor"]
- mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
- self.scaling = self.scaling * mscale * mscale
-
- # TODO, support head_size 192
- self.attn = RadixAttention(
- self.num_local_heads,
- 256,
- self.scaling,
- num_kv_heads=self.num_local_heads,
- layer_id=layer_id,
- quant_config=quant_config,
- prefix=add_prefix("attn", prefix),
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- forward_batch: ForwardBatch,
- ) -> torch.Tensor:
- if hidden_states.shape[0] == 0:
- assert (
- not self.o_proj.reduce_results
- ), "short-circuiting allreduce will lead to hangs"
- return hidden_states
-
- if self.q_lora_rank is not None:
- q = self.q_a_proj(hidden_states)[0]
- q = self.q_a_layernorm(q)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- else:
- q = self.q_proj(hidden_states)[0].view(
- -1, self.num_local_heads, self.qk_head_dim
- )
- _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
- kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
- kv = self.kv_b_proj(kv_a)[0]
- kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- k_pe = latent_cache[:, :, self.kv_lora_rank :]
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
- q[..., self.qk_nope_head_dim :] = q_pe
- k = torch.empty_like(q)
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
- q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- attn_output = self.attn(q, k, v, forward_batch)
- attn_output = attn_output.view(-1, self.num_local_heads, 256)[
- ..., : self.v_head_dim
- ].reshape(-1, self.num_local_heads * self.v_head_dim)
- output, _ = self.o_proj(attn_output)
- return output
-
-
  class DeepseekV2AttentionMLA(nn.Module):

  def __init__(
@@ -705,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.w_vc = None
  self.w_scale = None

+ self.w_scale_k = None
+ self.w_scale_v = None
+ self.use_deep_gemm_bmm = False
+
  self.flashinfer_mla_disable_ragged = global_server_args_dict[
  "flashinfer_mla_disable_ragged"
  ]
@@ -762,6 +596,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  if hidden_states.shape[0] == 0:
  assert (
@@ -787,9 +622,13 @@
  positions, hidden_states, forward_batch
  )
  else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+ return self.forward_absorb(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
  else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+ return self.forward_absorb(
+ positions, hidden_states, forward_batch, zero_allocator
+ )

  def forward_normal(
  self,
@@ -838,6 +677,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  q_len = hidden_states.shape[0]
  q_input = hidden_states.new_empty(
@@ -853,7 +693,24 @@
  )
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

- if self.w_kc.dtype == torch.float8_e4m3fnuz:
+ if self.use_deep_gemm_bmm:
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+ per_tensor_quant_mla_deep_gemm_masked_fp8(
+ q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ )
+ )
+ q_nope_out = q_nope.new_empty(
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
+ )
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+ (q_nope_val, q_nope_scale),
+ (self.w_kc, self.w_scale_k),
+ q_nope_out,
+ masked_m,
+ expected_m,
+ )
+ q_nope_out = q_nope_out[:, :expected_m, :]
+ elif self.w_kc.dtype == torch.float8_e4m3fnuz:
  # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
  q_nope_out = torch.bmm(
  q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -861,7 +718,8 @@
  )
  elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
- q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
  )
  q_nope_out = bmm_fp8(
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -884,7 +742,24 @@
  attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

- if self.w_vc.dtype == torch.float8_e4m3fnuz:
+ if self.use_deep_gemm_bmm:
+ attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+ per_tensor_quant_mla_deep_gemm_masked_fp8(
+ attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ )
+ )
+ attn_bmm_output = attn_output.new_empty(
+ (self.num_local_heads, aligned_m, self.v_head_dim)
+ )
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+ (attn_output_val, attn_output_scale),
+ (self.w_vc, self.w_scale_v),
+ attn_bmm_output,
+ masked_m,
+ expected_m,
+ )
+ attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+ elif self.w_vc.dtype == torch.float8_e4m3fnuz:
  # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
  attn_bmm_output = torch.bmm(
  attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -892,7 +767,8 @@
  )
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
- attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ attn_output.transpose(0, 1),
+ zero_allocator.allocate(1),
  )
  attn_bmm_output = bmm_fp8(
  attn_output_val,
@@ -913,6 +789,7 @@
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  enable_rope_fusion = (
  os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -939,7 +816,9 @@
  )
  elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
- q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
+ dtype=torch.float8_e4m3fn,
  )
  q_nope_out = bmm_fp8(
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1035,7 +914,9 @@
  )
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
- attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ attn_output.transpose(0, 1),
+ zero_allocator.allocate(1),
+ dtype=torch.float8_e4m3fn,
  )
  attn_bmm_output = bmm_fp8(
  attn_output_val,
@@ -1173,6 +1054,19 @@
  return output


+ class _FFNInputMode(Enum):
+ # The MLP sublayer requires 1/tp_size tokens as input
+ SCATTERED = auto()
+ # The MLP sublayer requires all tokens as input
+ FULL = auto()
+
+
+ @dataclass
+ class _DecoderLayerInfo:
+ is_sparse: bool
+ ffn_input_mode: _FFNInputMode
+
+
  class DeepseekV2DecoderLayer(nn.Module):

  def __init__(
@@ -1183,14 +1077,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  is_nextn: bool = False,
  prefix: str = "",
  ) -> None:
-
- def is_sparse_layer(l: int):
- return (
- config.n_routed_experts is not None
- and l >= config.first_k_dense_replace
- and l % config.moe_layer_freq == 0
- )
-
  super().__init__()
  self.hidden_size = config.hidden_size
  rope_theta = getattr(config, "rope_theta", 10000)
@@ -1201,68 +1087,54 @@ class DeepseekV2DecoderLayer(nn.Module):
  self.dp_size = get_attention_dp_size()
  self.attn_tp_size = get_attention_tp_size()
  self.attn_tp_rank = get_attention_tp_rank()
+ self.self_attn = DeepseekV2AttentionMLA(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ qk_nope_head_dim=config.qk_nope_head_dim,
+ qk_rope_head_dim=config.qk_rope_head_dim,
+ v_head_dim=config.v_head_dim,
+ q_lora_rank=(
+ config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+ ),
+ kv_lora_rank=config.kv_lora_rank,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ reduce_results=False,
+ prefix=add_prefix("self_attn", prefix),
+ )

- if not global_server_args_dict["disable_mla"]:
- self.self_attn = DeepseekV2AttentionMLA(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=config.v_head_dim,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- reduce_results=False,
- prefix=add_prefix("self_attn", prefix),
- )
- else:
- self.self_attn = DeepseekV2Attention(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=config.v_head_dim,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- reduce_results=False,
- prefix=add_prefix("self_attn", prefix),
- )
+ self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+ previous_layer_info = self._compute_info(
+ config, layer_id=layer_id - 1, is_nextn=False
+ )

- if is_nextn or is_sparse_layer(layer_id):
+ if self.info.is_sparse:
  self.mlp = DeepseekV2MoE(
  config=config,
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  )
- self.is_sparse = True
  else:
+ if self._enable_moe_dense_fully_dp():
+ mlp_tp_rank, mlp_tp_size = 0, 1
+ else:
+ mlp_tp_rank, mlp_tp_size = None, None
  self.mlp = DeepseekV2MLP(
  hidden_size=config.hidden_size,
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
+ tp_rank=mlp_tp_rank,
+ tp_size=mlp_tp_size,
  )
- self.is_sparse = False

  self.input_is_scattered = (
- is_sparse_layer(layer_id - 1)
- and global_server_args_dict["enable_deepep_moe"]
+ previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
  )
  self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
@@ -1271,28 +1143,51 @@ class DeepseekV2DecoderLayer(nn.Module):
  config.hidden_size, eps=config.rms_norm_eps
  )

+ @staticmethod
+ def _enable_moe_dense_fully_dp():
+ return global_server_args_dict["moe_dense_tp_size"] == 1
+
+ @staticmethod
+ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+ is_sparse = is_nextn or (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ )
+ ffn_input_mode = (
+ _FFNInputMode.SCATTERED
+ if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+ or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+ else _FFNInputMode.FULL
+ )
+ return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
  def forward(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
- if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
- return self.forward_deepep(
- positions, hidden_states, forward_batch, residual
+ if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+ return self.forward_ffn_with_scattered_input(
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
- else:
- return self.forward_normal(
- positions, hidden_states, forward_batch, residual
+ elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+ return self.forward_ffn_with_full_input(
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
+ else:
+ raise NotImplementedError

- def forward_normal(
+ def forward_ffn_with_full_input(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:

  if hidden_states.shape[0] == 0:
@@ -1313,6 +1208,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  positions=positions,
  hidden_states=hidden_states,
  forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
  )

  # Gather
@@ -1354,12 +1250,13 @@ class DeepseekV2DecoderLayer(nn.Module):

  return hidden_states, residual

- def forward_deepep(
+ def forward_ffn_with_scattered_input(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:

  if hidden_states.shape[0] == 0:
@@ -1385,6 +1282,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  positions=positions,
  hidden_states=hidden_states,
  forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
  )

  if self.attn_tp_size != 1:
@@ -1410,7 +1308,13 @@ class DeepseekV2DecoderLayer(nn.Module):
  hidden_states, residual = self.post_attention_layernorm(
  hidden_states, residual
  )
- hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+ if not (
+ self._enable_moe_dense_fully_dp()
+ and (not self.info.is_sparse)
+ and hidden_states.shape[0] == 0
+ ):
+ hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

  if self.is_last_layer and self.attn_tp_size != 1:
  hidden_states += residual
@@ -1466,6 +1370,14 @@ class DeepseekV2Model(nn.Module):
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
+ zero_allocator = BumpAllocator(
+ # TODO for two-batch-overlap, we need a larger buffer size
+ buffer_size=len(self.layers) * 2,
+ dtype=torch.float32,
+ device=(
+ input_embeds.device if input_embeds is not None else input_ids.device
+ ),
+ )

  if input_embeds is None:
  hidden_states = self.embed_tokens(input_ids)
@@ -1477,7 +1389,7 @@ class DeepseekV2Model(nn.Module):
  expert_distribution_recorder.set_current_layer(i)
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, forward_batch, residual
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
  if not forward_batch.forward_mode.is_idle():
  if residual is None:
@@ -1500,24 +1412,33 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
  self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
- # Only Deepseek V3/R1 can use shared experts fusion optimization now.
- if (
- global_server_args_dict.get("disable_shared_experts_fusion", False)
- or self.config.architectures[0] != "DeepseekV3ForCausalLM"
- or self.config.n_routed_experts != 256
- or self.config.routed_scaling_factor != 2.5
- ):
- self.n_share_experts_fusion = None
- global_server_args_dict["n_share_experts_fusion"] = None
- logger.info(
- "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
- )
- elif self.n_share_experts_fusion is None:
- global_server_args_dict["n_share_experts_fusion"] = self.tp_size
- self.n_share_experts_fusion = self.tp_size
- logger.info(
- f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
- )
+ if self.n_share_experts_fusion > 0:
+ # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+ if (
+ self.config.architectures[0] != "DeepseekV3ForCausalLM"
+ or self.config.n_routed_experts != 256
+ ):
+ self.n_share_experts_fusion = 0
+ global_server_args_dict["n_share_experts_fusion"] = 0
+ logger.info(
+ "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+ )
+ else:
+ assert (
+ self.n_share_experts_fusion == self.tp_size
+ ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+ elif self.n_share_experts_fusion == 0:
+ if (
+ torch.cuda.get_device_capability("cuda") >= (9, 0)
+ and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+ and self.config.n_routed_experts == 256
+ and (not global_server_args_dict["enable_deepep_moe"])
+ ):
+ self.n_share_experts_fusion = self.tp_size
+ global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+ logger.info(
+ "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+ )

  self.model = DeepseekV2Model(
  config, quant_config, prefix=add_prefix("model", prefix)
@@ -1552,78 +1473,92 @@ class DeepseekV2ForCausalLM(nn.Module):
  def post_load_weights(self):

  # Perform post-processing after loading weights
-
- if not global_server_args_dict["disable_mla"]:
- for layer_id in range(self.config.num_hidden_layers):
- self_attn = self.model.layers[layer_id].self_attn
- if hasattr(self_attn.kv_b_proj, "qweight"):
- # AWQ compatible
- if _is_cuda:
- w = awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- ).T
- else:
- w = ops.awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- 0,
- 0,
- 0,
- ).T
+ for layer_id in range(self.config.num_hidden_layers):
+ self_attn = self.model.layers[layer_id].self_attn
+ if hasattr(self_attn.kv_b_proj, "qweight"):
+ # AWQ compatible
+ if _is_cuda:
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ ).T
  else:
- w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- if w.dtype in (
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- ):
- if hasattr(self.quant_config, "weight_block_size"):
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- if _is_hip:
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=w,
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
- input_scale=None,
- )
- else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ 0,
+ 0,
+ 0,
+ ).T
+ else:
+ w = self_attn.kv_b_proj.weight
+ # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+ # This may affect the accuracy of fp8 model.
+ # Fix deepseek v3 blockwise bmm by using deep_gemm
+ use_deep_gemm_bmm = False
+ model_dtype = torch.get_default_dtype()
+
+ if w.dtype in (
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ ):
+ if hasattr(self.quant_config, "weight_block_size"):
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ if _is_hip:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv

+ if (
+ _is_cuda
+ and _enable_jit_deepgemm_bmm
+ and weight_block_size[0] == 128
+ and weight_block_size[1] == 128
+ and model_dtype == torch.bfloat16
+ ):
+ block_scale = weight_scale
+ use_deep_gemm_bmm = True
+ else:
  w, scale = block_quant_to_tensor_quant(
  weight, weight_scale, weight_block_size
  )
  self_attn.w_scale = scale
- else:
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale
+ w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+ self_attn.w_scale = scale
+
+ if w.dtype == torch.int8:
+ if hasattr(self.quant_config, "weight_block_size"):
+ # block-wise int8 need it
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
  weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale
- w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
- self_attn.w_scale = scale
-
- if w.dtype == torch.int8:
- if hasattr(self.quant_config, "weight_block_size"):
- # block-wise int8 need it
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
- w = int8_block_dequant(
- weight, weight_scale, weight_block_size
- ).to(torch.bfloat16)
- else:
- # channel-wise int8 need it
- w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
- torch.bfloat16
- )
- w_kc, w_vc = w.unflatten(
- 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
- ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = int8_block_dequant(
+ weight, weight_scale, weight_block_size
+ ).to(torch.bfloat16)
+ else:
+ # channel-wise int8 need it
+ w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+ torch.bfloat16
+ )
+
+ w_kc, w_vc = w.unflatten(
+ 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+ ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ if not use_deep_gemm_bmm:
  self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
  self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
  if (
@@ -1633,6 +1568,17 @@ class DeepseekV2ForCausalLM(nn.Module):
  self_attn.w_scale = self_attn.kv_b_proj.weight_scale
  if _is_hip:
  self_attn.w_scale *= 2.0
+ else:
+ num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+ num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+ ws_kc, ws_vc = block_scale.unflatten(
+ 0, (-1, (num_tiles_k + num_tiles_n))
+ ).split([num_tiles_k, num_tiles_n], dim=1)
+ self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+ self_attn.w_scale_v = ws_vc.contiguous()
+ self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+ self_attn.w_vc = w_vc.contiguous()
+ self_attn.use_deep_gemm_bmm = True

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
@@ -1640,7 +1586,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
- if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+ if self.n_share_experts_fusion > 0:
  weights_list = list(weights)
  weights_dict = dict(weights_list)
  if self.quant_config.get_name() == "w8a8_int8":
@@ -1699,12 +1645,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts
- + (
- self.n_share_experts_fusion
- if self.n_share_experts_fusion is not None
- else 0
- ),
+ num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
  )

  params_dict = dict(self.named_parameters())
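
Reading aid for the zero_allocator plumbing above: DeepseekV2Model.forward now builds a BumpAllocator (imported from sglang.srt.utils in this diff) with two float32 slots per decoder layer and threads it through every layer down to the MLA attention paths, where zero_allocator.allocate(1) supplies a small pre-allocated tensor to per_tensor_quant_mla_fp8. The toy class below is not sglang's implementation; it is a minimal sketch, assuming only the constructor keywords and allocate(num) call visible in the hunks, of how a bump allocator can hand out slices of one zero-initialized buffer instead of allocating a fresh tensor per call. ToyBumpAllocator and its internals are hypothetical.

import torch


class ToyBumpAllocator:
    # Illustrative stand-in for sglang.srt.utils.BumpAllocator; internals are assumed.
    def __init__(self, buffer_size: int, dtype: torch.dtype, device):
        # One up-front, zero-filled buffer; callers receive views into it.
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, num: int) -> torch.Tensor:
        # Hand out the next `num` slots as a view, bumping the offset.
        assert self._offset + num <= self._buffer.numel(), "buffer exhausted"
        out = self._buffer[self._offset : self._offset + num]
        self._offset += num
        return out


# Usage pattern analogous to DeepseekV2Model.forward in the diff:
# one allocator per forward pass, two slots reserved per decoder layer.
num_layers = 4
zero_allocator = ToyBumpAllocator(
    buffer_size=num_layers * 2, dtype=torch.float32, device="cpu"
)
for _ in range(num_layers):
    zero_slot = zero_allocator.allocate(1)  # mirrors zero_allocator.allocate(1) above
    assert zero_slot.shape == (1,) and zero_slot.item() == 0.0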