sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,8 @@
 
  import logging
  import os
- from enum import IntEnum, auto
+ from dataclasses import dataclass
+ from enum import Enum, IntEnum, auto
  from typing import Any, Dict, Iterable, Optional, Tuple
 
  import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
  from transformers import PretrainedConfig
 
  from sglang.srt.distributed import (
+ get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  parallel_state,
  tensor_model_parallel_all_reduce,
@@ -51,10 +53,15 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
+ from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+ from sglang.srt.layers.quantization.fp8_kernel import (
+ per_tensor_quant_mla_deep_gemm_masked_fp8,
+ per_tensor_quant_mla_fp8,
+ )
  from sglang.srt.layers.quantization.fp8_utils import (
  block_quant_to_tensor_quant,
  channel_quant_to_tensor_quant,
@@ -73,7 +80,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+ from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
  _is_hip = is_hip()
  _is_cuda = is_cuda()
@@ -81,9 +88,11 @@ _is_cuda = is_cuda()
  if _is_cuda:
  from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
 
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
+ from sglang.srt.layers.quantization.deep_gemm import (
+ grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
+ )
  else:
- from vllm import _custom_ops as ops
+ from vllm._custom_ops import awq_dequantize
 
  if _is_hip:
  from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -96,7 +105,6 @@ logger = logging.getLogger(__name__)
 
 
  class AttnForwardMethod(IntEnum):
-
  # Use multi-head attention
  MHA = auto()
 
@@ -147,7 +155,7 @@ class DeepseekV2MLP(nn.Module):
  )
  self.act_fn = SiluAndMul()
 
- def forward(self, x):
+ def forward(self, x, forward_mode: Optional[ForwardMode] = None):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
  x, _ = self.down_proj(x)
@@ -188,11 +196,7 @@ class DeepseekV2MoE(nn.Module):
  self.tp_size = get_tensor_model_parallel_world_size()
  self.routed_scaling_factor = config.routed_scaling_factor
  self.n_shared_experts = config.n_shared_experts
- self.n_share_experts_fusion = (
- global_server_args_dict["n_share_experts_fusion"]
- if global_server_args_dict["n_share_experts_fusion"] is not None
- else 0
- )
+ self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
  if self.tp_size > config.n_routed_experts:
  raise ValueError(
@@ -225,6 +229,7 @@ class DeepseekV2MoE(nn.Module):
  num_expert_group=config.n_group,
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
  dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -333,6 +338,7 @@ class DeepseekV2MoE(nn.Module):
  topk_group=self.topk_group,
  num_expert_group=self.num_expert_group,
  correction_bias=self.correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
  )
  if self.ep_size > 1:
  # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -373,7 +379,7 @@ class DeepseekV2MoE(nn.Module):
  return final_hidden_states
 
  def _forward_shared_experts(self, hidden_states):
- if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+ if self.n_share_experts_fusion == 0:
  return self.shared_experts(hidden_states)
  else:
  return None
@@ -387,179 +393,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  return 0.1 * mscale * math.log(scale) + 1.0
 
 
- class DeepseekV2Attention(nn.Module):
-
- def __init__(
- self,
- config: PretrainedConfig,
- hidden_size: int,
- num_heads: int,
- qk_nope_head_dim: int,
- qk_rope_head_dim: int,
- v_head_dim: int,
- q_lora_rank: int,
- kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: Optional[Dict[str, Any]] = None,
- max_position_embeddings: int = 8192,
- quant_config: Optional[QuantizationConfig] = None,
- layer_id=None,
- reduce_results: bool = True,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.layer_id = layer_id
- self.hidden_size = hidden_size
- self.qk_nope_head_dim = qk_nope_head_dim
- self.qk_rope_head_dim = qk_rope_head_dim
- self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
- self.v_head_dim = v_head_dim
- self.q_lora_rank = q_lora_rank
- self.kv_lora_rank = kv_lora_rank
-
- self.dp_size = get_attention_dp_size()
- attn_tp_rank = get_attention_tp_rank()
- attn_tp_size = get_attention_tp_size()
-
- self.num_heads = num_heads
- assert num_heads % attn_tp_size == 0
- self.num_local_heads = num_heads // attn_tp_size
- self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
- self.max_position_embeddings = max_position_embeddings
-
- if self.q_lora_rank is not None:
- self.q_a_proj = ReplicatedLinear(
- self.hidden_size,
- self.q_lora_rank,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_a_proj", prefix),
- )
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ColumnParallelLinear(
- q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_b_proj", prefix),
- )
- else:
- self.q_proj = ColumnParallelLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_proj", prefix),
- tp_rank=attn_tp_rank,
- tp_size=attn_tp_size,
- )
-
- self.kv_a_proj_with_mqa = ReplicatedLinear(
- self.hidden_size,
- self.kv_lora_rank + self.qk_rope_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_a_proj_with_mqa", prefix),
- )
- self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
- self.kv_b_proj = ColumnParallelLinear(
- self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_b_proj", prefix),
- )
- # O projection.
- self.o_proj = RowParallelLinear(
- self.num_heads * self.v_head_dim,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("o_proj", prefix),
- reduce_results=reduce_results,
- tp_rank=attn_tp_rank,
- tp_size=attn_tp_size,
- )
- rope_scaling["rope_type"] = "deepseek_yarn"
- self.rotary_emb = get_rope_wrapper(
- qk_rope_head_dim,
- rotary_dim=qk_rope_head_dim,
- max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
- is_neox_style=False,
- device=global_server_args_dict["device"],
- )
-
- if rope_scaling:
- mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
- scaling_factor = rope_scaling["factor"]
- mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
- self.scaling = self.scaling * mscale * mscale
-
- # TODO, support head_size 192
- self.attn = RadixAttention(
- self.num_local_heads,
- 256,
- self.scaling,
- num_kv_heads=self.num_local_heads,
- layer_id=layer_id,
- quant_config=quant_config,
- prefix=add_prefix("attn", prefix),
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- forward_batch: ForwardBatch,
- ) -> torch.Tensor:
- if hidden_states.shape[0] == 0:
- assert (
- not self.o_proj.reduce_results
- ), "short-circuiting allreduce will lead to hangs"
- return hidden_states
-
- if self.q_lora_rank is not None:
- q = self.q_a_proj(hidden_states)[0]
- q = self.q_a_layernorm(q)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- else:
- q = self.q_proj(hidden_states)[0].view(
- -1, self.num_local_heads, self.qk_head_dim
- )
- _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
- kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
- kv = self.kv_b_proj(kv_a)[0]
- kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- k_pe = latent_cache[:, :, self.kv_lora_rank :]
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
- q[..., self.qk_nope_head_dim :] = q_pe
- k = torch.empty_like(q)
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
- q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
- -1, self.num_local_heads * 256
- )
- attn_output = self.attn(q, k, v, forward_batch)
- attn_output = attn_output.view(-1, self.num_local_heads, 256)[
- ..., : self.v_head_dim
- ].reshape(-1, self.num_local_heads * self.v_head_dim)
- output, _ = self.o_proj(attn_output)
- return output
-
-
  class DeepseekV2AttentionMLA(nn.Module):
 
  def __init__(
@@ -705,6 +538,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.w_vc = None
  self.w_scale = None
 
+ self.w_scale_k = None
+ self.w_scale_v = None
+ self.use_deep_gemm_bmm = False
+
  self.flashinfer_mla_disable_ragged = global_server_args_dict[
  "flashinfer_mla_disable_ragged"
  ]
@@ -762,6 +599,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  if hidden_states.shape[0] == 0:
  assert (
@@ -787,9 +625,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions, hidden_states, forward_batch
  )
  else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+ return self.forward_absorb(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
  else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+ return self.forward_absorb(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
 
  def forward_normal(
  self,
@@ -838,6 +680,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  q_len = hidden_states.shape[0]
  q_input = hidden_states.new_empty(
@@ -853,7 +696,24 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
- if self.w_kc.dtype == torch.float8_e4m3fnuz:
+ if self.use_deep_gemm_bmm:
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+ per_tensor_quant_mla_deep_gemm_masked_fp8(
+ q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ )
+ )
+ q_nope_out = q_nope.new_empty(
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
+ )
+ deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ (q_nope_val, q_nope_scale),
+ (self.w_kc, self.w_scale_k),
+ q_nope_out,
+ masked_m,
+ expected_m,
+ )
+ q_nope_out = q_nope_out[:, :expected_m, :]
+ elif self.w_kc.dtype == torch.float8_e4m3fnuz:
  # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
  q_nope_out = torch.bmm(
  q_nope.to(torch.bfloat16).transpose(0, 1),
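The `use_deep_gemm_bmm` branch above swaps the per-tensor `bmm_fp8` path for DeepGEMM's masked grouped GEMM: `per_tensor_quant_mla_deep_gemm_masked_fp8` quantizes the head-major activations and returns per-group metadata (`masked_m`, `expected_m`, `aligned_m`), `deep_gemm_grouped_gemm_nt_f8f8bf16_masked` multiplies each head's activation block with that head's `w_kc` / `w_scale_k` weights into a pre-allocated `(num_local_heads, aligned_m, kv_lora_rank)` buffer, and only the first `expected_m` rows are kept. The snippet below is a simplified, FP8-free reference of that masked, head-grouped layout; `masked_grouped_gemm_reference` and its toy shapes are illustrative and are not the DeepGEMM kernel API.

```python
import torch

def masked_grouped_gemm_reference(
    a: torch.Tensor,         # (num_groups, m_max, k) activations, e.g. q_nope per head
    b: torch.Tensor,         # (num_groups, n, k) weights, e.g. w_kc per head
    out: torch.Tensor,       # (num_groups, m_max, n) pre-allocated output
    masked_m: torch.Tensor,  # (num_groups,) valid row count per group
) -> torch.Tensor:
    # For each group g, only the first masked_m[g] rows are computed ("nt" layout:
    # the weight is stored (n, k) and multiplied transposed); padded rows stay untouched.
    for g in range(a.shape[0]):
        m = int(masked_m[g])
        out[g, :m] = a[g, :m] @ b[g].transpose(0, 1)
    return out

# Toy shapes loosely mirroring the MLA usage above
# (groups = attention heads, k = qk_nope_head_dim, n = kv_lora_rank).
heads, m_max, k, n = 4, 8, 16, 32
a = torch.randn(heads, m_max, k)
b = torch.randn(heads, n, k)
out = torch.zeros(heads, m_max, n)
masked_m = torch.tensor([8, 5, 8, 3])
masked_grouped_gemm_reference(a, b, out, masked_m)
```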
@@ -861,7 +721,8 @@
  )
  elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
- q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
  )
  q_nope_out = bmm_fp8(
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -884,7 +745,24 @@
  attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
- if self.w_vc.dtype == torch.float8_e4m3fnuz:
+ if self.use_deep_gemm_bmm:
+ attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+ per_tensor_quant_mla_deep_gemm_masked_fp8(
+ attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ )
+ )
+ attn_bmm_output = attn_output.new_empty(
+ (self.num_local_heads, aligned_m, self.v_head_dim)
+ )
+ deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ (attn_output_val, attn_output_scale),
+ (self.w_vc, self.w_scale_v),
+ attn_bmm_output,
+ masked_m,
+ expected_m,
+ )
+ attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+ elif self.w_vc.dtype == torch.float8_e4m3fnuz:
  # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
  attn_bmm_output = torch.bmm(
  attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -892,7 +770,8 @@
  )
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
- attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ attn_output.transpose(0, 1),
+ zero_allocator.allocate(1),
  )
  attn_bmm_output = bmm_fp8(
  attn_output_val,
@@ -913,6 +792,7 @@
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
  enable_rope_fusion = (
  os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -939,7 +819,9 @@
  )
  elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
- q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
+ dtype=torch.float8_e4m3fn,
  )
  q_nope_out = bmm_fp8(
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1035,7 +917,9 @@
  )
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
- attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+ attn_output.transpose(0, 1),
+ zero_allocator.allocate(1),
+ dtype=torch.float8_e4m3fn,
  )
  attn_bmm_output = bmm_fp8(
  attn_output_val,
@@ -1173,6 +1057,19 @@
  return output
 
 
+ class _FFNInputMode(Enum):
+ # The MLP sublayer requires 1/tp_size tokens as input
+ SCATTERED = auto()
+ # The MLP sublayer requires all tokens as input
+ FULL = auto()
+
+
+ @dataclass
+ class _DecoderLayerInfo:
+ is_sparse: bool
+ ffn_input_mode: _FFNInputMode
+
+
  class DeepseekV2DecoderLayer(nn.Module):
 
  def __init__(
@@ -1183,14 +1080,6 @@
  is_nextn: bool = False,
  prefix: str = "",
  ) -> None:
-
- def is_sparse_layer(l: int):
- return (
- config.n_routed_experts is not None
- and l >= config.first_k_dense_replace
- and l % config.moe_layer_freq == 0
- )
-
  super().__init__()
  self.hidden_size = config.hidden_size
  rope_theta = getattr(config, "rope_theta", 10000)
@@ -1201,68 +1090,54 @@
  self.dp_size = get_attention_dp_size()
  self.attn_tp_size = get_attention_tp_size()
  self.attn_tp_rank = get_attention_tp_rank()
+ self.self_attn = DeepseekV2AttentionMLA(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ qk_nope_head_dim=config.qk_nope_head_dim,
+ qk_rope_head_dim=config.qk_rope_head_dim,
+ v_head_dim=config.v_head_dim,
+ q_lora_rank=(
+ config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+ ),
+ kv_lora_rank=config.kv_lora_rank,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ reduce_results=False,
+ prefix=add_prefix("self_attn", prefix),
+ )
 
- if not global_server_args_dict["disable_mla"]:
- self.self_attn = DeepseekV2AttentionMLA(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=config.v_head_dim,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- reduce_results=False,
- prefix=add_prefix("self_attn", prefix),
- )
- else:
- self.self_attn = DeepseekV2Attention(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=config.v_head_dim,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- reduce_results=False,
- prefix=add_prefix("self_attn", prefix),
- )
+ self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+ previous_layer_info = self._compute_info(
+ config, layer_id=layer_id - 1, is_nextn=False
+ )
 
- if is_nextn or is_sparse_layer(layer_id):
+ if self.info.is_sparse:
  self.mlp = DeepseekV2MoE(
  config=config,
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
  )
- self.is_sparse = True
  else:
+ if self._enable_moe_dense_fully_dp():
+ mlp_tp_rank, mlp_tp_size = 0, 1
+ else:
+ mlp_tp_rank, mlp_tp_size = None, None
  self.mlp = DeepseekV2MLP(
  hidden_size=config.hidden_size,
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
  prefix=add_prefix("mlp", prefix),
+ tp_rank=mlp_tp_rank,
+ tp_size=mlp_tp_size,
  )
- self.is_sparse = False
 
  self.input_is_scattered = (
- is_sparse_layer(layer_id - 1)
- and global_server_args_dict["enable_deepep_moe"]
+ previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
  )
  self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1271,28 +1146,51 @@
  config.hidden_size, eps=config.rms_norm_eps
  )
 
+ @staticmethod
+ def _enable_moe_dense_fully_dp():
+ return global_server_args_dict["moe_dense_tp_size"] == 1
+
+ @staticmethod
+ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+ is_sparse = is_nextn or (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ )
+ ffn_input_mode = (
+ _FFNInputMode.SCATTERED
+ if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+ or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+ else _FFNInputMode.FULL
+ )
+ return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
  def forward(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
- if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
- return self.forward_deepep(
- positions, hidden_states, forward_batch, residual
+ if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+ return self.forward_ffn_with_scattered_input(
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
- else:
- return self.forward_normal(
- positions, hidden_states, forward_batch, residual
+ elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+ return self.forward_ffn_with_full_input(
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
+ else:
+ raise NotImplementedError
 
- def forward_normal(
+ def forward_ffn_with_full_input(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
 
  if hidden_states.shape[0] == 0:
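For orientation, the sketch below restates the classification performed by `_compute_info` above as a self-contained function: a layer is sparse when it is the NextN layer or a routed-experts layer past `first_k_dense_replace`, and its FFN runs on scattered (1/tp_size) tokens when DeepEP MoE is enabled for sparse layers, or when the dense MLP runs fully data-parallel (`moe_dense_tp_size == 1`). The names `FFNInputMode`, `LayerInfo`, and `compute_layer_info` and the explicit flag arguments are illustrative only; in the diff the flags come from `global_server_args_dict`, and the example config values mirror DeepSeek-V3's published defaults.

```python
from dataclasses import dataclass
from enum import Enum, auto
from types import SimpleNamespace

class FFNInputMode(Enum):
    SCATTERED = auto()  # MLP sees 1/tp_size of the tokens
    FULL = auto()       # MLP sees all tokens

@dataclass
class LayerInfo:
    is_sparse: bool
    ffn_input_mode: FFNInputMode

def compute_layer_info(config, layer_id, is_nextn, enable_deepep_moe, moe_dense_tp_size):
    # Mirrors _compute_info above, with server args passed explicitly
    # instead of read from global_server_args_dict.
    is_sparse = is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )
    dense_fully_dp = moe_dense_tp_size == 1
    scattered = (enable_deepep_moe and is_sparse) or (dense_fully_dp and not is_sparse)
    return LayerInfo(is_sparse, FFNInputMode.SCATTERED if scattered else FFNInputMode.FULL)

# Example with DeepSeek-V3-like values: the first 3 layers stay dense.
cfg = SimpleNamespace(n_routed_experts=256, first_k_dense_replace=3, moe_layer_freq=1)
print(compute_layer_info(cfg, 0, False, enable_deepep_moe=True, moe_dense_tp_size=None))
print(compute_layer_info(cfg, 10, False, enable_deepep_moe=True, moe_dense_tp_size=None))
```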
@@ -1313,6 +1211,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  positions=positions,
  hidden_states=hidden_states,
  forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
  )
 
  # Gather
@@ -1354,12 +1253,13 @@
 
  return hidden_states, residual
 
- def forward_deepep(
+ def forward_ffn_with_scattered_input(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
+ zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
 
  if hidden_states.shape[0] == 0:
@@ -1385,6 +1285,7 @@
  positions=positions,
  hidden_states=hidden_states,
  forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
  )
 
  if self.attn_tp_size != 1:
@@ -1410,7 +1311,13 @@
  hidden_states, residual = self.post_attention_layernorm(
  hidden_states, residual
  )
- hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+ if not (
+ self._enable_moe_dense_fully_dp()
+ and (not self.info.is_sparse)
+ and hidden_states.shape[0] == 0
+ ):
+ hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
  if self.is_last_layer and self.attn_tp_size != 1:
  hidden_states += residual
@@ -1466,6 +1373,14 @@ class DeepseekV2Model(nn.Module):
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
+ zero_allocator = BumpAllocator(
+ # TODO for two-batch-overlap, we need a larger buffer size
+ buffer_size=len(self.layers) * 2,
+ dtype=torch.float32,
+ device=(
+ input_embeds.device if input_embeds is not None else input_ids.device
+ ),
+ )
 
  if input_embeds is None:
  hidden_states = self.embed_tokens(input_ids)
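`DeepseekV2Model.forward` now builds one `BumpAllocator` per forward pass and threads it through every decoder layer down to the attention module, whose FP8 paths pass `zero_allocator.allocate(1)` into `per_tensor_quant_mla_fp8` instead of materializing a fresh tensor on every call; `buffer_size=len(self.layers) * 2` covers the two such calls per layer in `forward_absorb`. Below is a minimal sketch of an allocator with that interface, assuming it hands out slices of a single pre-zeroed float32 buffer; it illustrates the pattern and is not necessarily sglang's actual `BumpAllocator`.

```python
import torch

class ZeroBumpAllocator:
    """Minimal sketch: pre-allocate one zero-filled buffer and hand out
    slices of it, so per-layer quant calls avoid tiny per-call allocations."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        assert self._offset + size <= self._buffer.numel(), "bump buffer exhausted"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out

# Usage mirroring the hunk above: two one-element allocations per decoder layer.
num_layers = 4
alloc = ZeroBumpAllocator(num_layers * 2, torch.float32, torch.device("cpu"))
scale_slot = alloc.allocate(1)  # would be handed to per_tensor_quant_mla_fp8
print(scale_slot)  # tensor([0.])
```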
@@ -1477,7 +1392,7 @@
  expert_distribution_recorder.set_current_layer(i)
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, forward_batch, residual
+ positions, hidden_states, forward_batch, residual, zero_allocator
  )
  if not forward_batch.forward_mode.is_idle():
  if residual is None:
@@ -1500,24 +1415,33 @@
  self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
  self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
- # Only Deepseek V3/R1 can use shared experts fusion optimization now.
- if (
- global_server_args_dict.get("disable_shared_experts_fusion", False)
- or self.config.architectures[0] != "DeepseekV3ForCausalLM"
- or self.config.n_routed_experts != 256
- or self.config.routed_scaling_factor != 2.5
- ):
- self.n_share_experts_fusion = None
- global_server_args_dict["n_share_experts_fusion"] = None
- logger.info(
- "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
- )
- elif self.n_share_experts_fusion is None:
- global_server_args_dict["n_share_experts_fusion"] = self.tp_size
- self.n_share_experts_fusion = self.tp_size
- logger.info(
- f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
- )
+ if self.n_share_experts_fusion > 0:
+ # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+ if (
+ self.config.architectures[0] != "DeepseekV3ForCausalLM"
+ or self.config.n_routed_experts != 256
+ ):
+ self.n_share_experts_fusion = 0
+ global_server_args_dict["n_share_experts_fusion"] = 0
+ logger.info(
+ "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+ )
+ else:
+ assert (
+ self.n_share_experts_fusion == self.tp_size
+ ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+ elif self.n_share_experts_fusion == 0:
+ if (
+ torch.cuda.get_device_capability("cuda") >= (9, 0)
+ and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+ and self.config.n_routed_experts == 256
+ and (not global_server_args_dict["enable_deepep_moe"])
+ ):
+ self.n_share_experts_fusion = self.tp_size
+ global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+ logger.info(
+ "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+ )
 
  self.model = DeepseekV2Model(
  config, quant_config, prefix=add_prefix("model", prefix)
@@ -1552,78 +1476,92 @@ class DeepseekV2ForCausalLM(nn.Module):
  def post_load_weights(self):
 
  # Perform post-processing after loading weights
-
- if not global_server_args_dict["disable_mla"]:
- for layer_id in range(self.config.num_hidden_layers):
- self_attn = self.model.layers[layer_id].self_attn
- if hasattr(self_attn.kv_b_proj, "qweight"):
- # AWQ compatible
- if _is_cuda:
- w = awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- ).T
- else:
- w = ops.awq_dequantize(
- self_attn.kv_b_proj.qweight,
- self_attn.kv_b_proj.scales,
- self_attn.kv_b_proj.qzeros,
- 0,
- 0,
- 0,
- ).T
+ for layer_id in range(self.config.num_hidden_layers):
+ self_attn = self.model.layers[layer_id].self_attn
+ if hasattr(self_attn.kv_b_proj, "qweight"):
+ # AWQ compatible
+ if _is_cuda:
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ ).T
  else:
- w = self_attn.kv_b_proj.weight
- # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
- # This may affect the accuracy of fp8 model.
- if w.dtype in (
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- ):
- if hasattr(self.quant_config, "weight_block_size"):
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- if _is_hip:
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=w,
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
- input_scale=None,
- )
- else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = awq_dequantize(
+ self_attn.kv_b_proj.qweight,
+ self_attn.kv_b_proj.scales,
+ self_attn.kv_b_proj.qzeros,
+ 0,
+ 0,
+ 0,
+ ).T
+ else:
+ w = self_attn.kv_b_proj.weight
+ # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+ # This may affect the accuracy of fp8 model.
+ # Fix deepseek v3 blockwise bmm by using deep_gemm
+ use_deep_gemm_bmm = False
+ model_dtype = torch.get_default_dtype()
+
+ if w.dtype in (
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ ):
+ if hasattr(self.quant_config, "weight_block_size"):
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ if _is_hip:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
+ if (
+ _is_cuda
+ and _ENABLE_JIT_DEEPGEMM
+ and weight_block_size[0] == 128
+ and weight_block_size[1] == 128
+ and model_dtype == torch.bfloat16
+ ):
+ block_scale = weight_scale
+ use_deep_gemm_bmm = True
+ else:
  w, scale = block_quant_to_tensor_quant(
  weight, weight_scale, weight_block_size
  )
  self_attn.w_scale = scale
- else:
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale
+ w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+ self_attn.w_scale = scale
+
+ if w.dtype == torch.int8:
+ if hasattr(self.quant_config, "weight_block_size"):
+ # block-wise int8 need it
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
  weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale
- w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
- self_attn.w_scale = scale
-
- if w.dtype == torch.int8:
- if hasattr(self.quant_config, "weight_block_size"):
- # block-wise int8 need it
- weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
- w = int8_block_dequant(
- weight, weight_scale, weight_block_size
- ).to(torch.bfloat16)
- else:
- # channel-wise int8 need it
- w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
- torch.bfloat16
- )
- w_kc, w_vc = w.unflatten(
- 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
- ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = int8_block_dequant(
+ weight, weight_scale, weight_block_size
+ ).to(torch.bfloat16)
+ else:
+ # channel-wise int8 need it
+ w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+ torch.bfloat16
+ )
+
+ w_kc, w_vc = w.unflatten(
+ 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+ ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ if not use_deep_gemm_bmm:
  self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
  self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
  if (
@@ -1633,6 +1571,17 @@ class DeepseekV2ForCausalLM(nn.Module):
  self_attn.w_scale = self_attn.kv_b_proj.weight_scale
  if _is_hip:
  self_attn.w_scale *= 2.0
+ else:
+ num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+ num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+ ws_kc, ws_vc = block_scale.unflatten(
+ 0, (-1, (num_tiles_k + num_tiles_n))
+ ).split([num_tiles_k, num_tiles_n], dim=1)
+ self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+ self_attn.w_scale_v = ws_vc.contiguous()
+ self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+ self_attn.w_vc = w_vc.contiguous()
+ self_attn.use_deep_gemm_bmm = True
 
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
@@ -1640,7 +1589,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
- if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+ if self.n_share_experts_fusion > 0:
  weights_list = list(weights)
  weights_dict = dict(weights_list)
  if self.quant_config.get_name() == "w8a8_int8":
@@ -1682,7 +1631,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  f"mlp.experts."
  f"{self.config.n_routed_experts + num_repeat}"
  f".{suffix}",
- weights_dict[shared_expert_weight_name].clone(),
+ weights_dict[shared_expert_weight_name],
  )
  )
  names_to_remove += [shared_expert_weight_name]
@@ -1699,12 +1648,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts
- + (
- self.n_share_experts_fusion
- if self.n_share_experts_fusion is not None
- else 0
- ),
+ num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
  )
 
  params_dict = dict(self.named_parameters())