sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
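
The hunks reproduced below come from three of the changed files: sglang/srt/models/deepseek_janus_pro.py (the Janus-Pro model migrating to the new shared multimodal utilities), sglang/srt/models/deepseek_nextn.py, and sglang/srt/models/deepseek_v2.py (conditional sgl_kernel AWQ dequantization and the new DeepEP MoE path).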
sglang/srt/models/deepseek_janus_pro.py

@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
             high_x, low_x = x_or_tuple
             high_x = self.high_up_proj(high_x)
             low_x = self.low_up_proj(low_x)
-            x = torch.concat([high_x, low_x], dim=-1)
+            x = torch.cat([high_x, low_x], dim=-1)
         else:
             x = x_or_tuple
 
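Note: torch.concat is a documented alias of torch.cat, so the rename above is cosmetic; both calls concatenate the high- and low-resolution projections along the last dimension.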
@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)
 
-    def prepare_images_seq_mask(
-        self, input_ids: torch.Tensor, image_inputs: ImageInputs
-    ) -> Optional[torch.LongTensor]:
-        images_seq_mask = torch.isin(
-            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
         )
-        if images_seq_mask.sum() == 0:
-            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
-            return None
-        else:
-            return images_seq_mask
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens
 
     @torch.no_grad()
     def forward(
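
The two rearrange calls in get_image_feature fold the per-request image axis into the batch axis for the vision tower, then unfold it back so each request carries one contiguous run of image tokens. A standalone shape walk-through with dummy tensors (all sizes below are illustrative, not Janus-Pro's actual dimensions):

    import torch
    from einops import rearrange

    bs, n, c, h, w = 2, 3, 3, 384, 384          # 2 requests, 3 images each
    pixel_values = torch.randn(bs, n, c, h, w)

    # fold the per-request image axis into the batch axis for the vision tower
    images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
    assert images.shape == (bs * n, c, h, w)

    # pretend the vision tower + aligner emit T2 tokens of width D per image
    T2, D = 576, 2048
    images_embeds = torch.randn(bs * n, T2, D)

    # unfold back: each request now owns one contiguous run of n * T2 tokens
    images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
    assert images_embeds.shape == (bs, n * T2, D)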
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
 
-        inputs_embeds = None
-        if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) != 0
-            and forward_batch.image_inputs[0] is not None
-        ):
-
-            image_inputs = forward_batch.image_inputs[0]
-
-            images_seq_mask = self.prepare_images_seq_mask(
-                input_ids=input_ids, image_inputs=image_inputs
-            )
-
-            if images_seq_mask is not None:
-                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-                inputs_embeds = self.prepare_inputs_embeds(
-                    input_ids=input_ids,
-                    pixel_values=image_inputs.pixel_values,
-                    images_seq_mask=images_seq_mask,
-                    images_emb_mask=image_inputs.images_emb_mask,
-                )
-                input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
 
         return self.language_model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )
 
-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         im_start_id = image_inputs.im_start_id
         im_end_id = image_inputs.im_end_id
         media_token_pairs = [(im_start_id, im_end_id)]
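
general_mm_embed_routine replaces the hand-rolled embedding merge that each multimodal model used to implement (the deleted prepare_inputs_embeds above). Its internals are not shown in this diff; conceptually, such a routine embeds the text tokens and then overwrites the positions held by image placeholder ids with the output of the model-supplied mm_data_embedding_func. A minimal sketch of that contract (the function and the placeholder-id convention below are illustrative assumptions, not sglang's actual implementation):

    import torch
    from torch import nn

    def toy_mm_embed_routine(
        input_ids: torch.Tensor,      # [T]; image slots hold placeholder ids
        image_pad_ids: torch.Tensor,  # ids that mark image slots
        embed_tokens: nn.Embedding,
        mm_data_embedding_func,       # callable returning [num_image_tokens, D]
    ) -> torch.Tensor:
        mask = torch.isin(input_ids, image_pad_ids)
        # clamp so out-of-vocab placeholder ids survive the embedding lookup
        embeds = embed_tokens(input_ids.clamp(0, embed_tokens.num_embeddings - 1))
        if mask.any():
            image_embeds = mm_data_embedding_func().to(embeds.dtype)
            embeds[mask] = image_embeds.reshape(-1, embeds.shape[-1])[: int(mask.sum())]
        return embeds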
sglang/srt/models/deepseek_nextn.py

@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
+else:
+    from vllm import _custom_ops as ops
 
 
 class DeepseekModelNextN(nn.Module):
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-            w = ops.awq_dequantize(
-                self_attn.kv_b_proj.qweight,
-                self_attn.kv_b_proj.scales,
-                self_attn.kv_b_proj.qzeros,
-                0,
-                0,
-                0,
-            ).T
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
         # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
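
sgl_kernel's awq_dequantize takes only the three tensors, while vllm's legacy op takes three additional positional ints (passed as zeros here). Since the same branch now appears at two call sites (here and in deepseek_v2.py below), it could be factored into a helper; a sketch, assuming only the two signatures visible in this diff:

    import torch

    def dequantize_awq_kv_b_proj(qweight, scales, qzeros) -> torch.Tensor:
        # Prefer the sgl_kernel CUDA op; fall back to vllm's legacy kernel.
        try:
            from sgl_kernel import awq_dequantize
            return awq_dequantize(qweight, scales, qzeros).T
        except ImportError:
            from vllm import _custom_ops as ops
            return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0).T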
sglang/srt/models/deepseek_v2.py

@@ -23,10 +23,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
@@ -34,7 +34,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -48,8 +48,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -65,15 +67,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
-if is_cuda_available():
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
+else:
+    from vllm import _custom_ops as ops
+
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 class DeepseekV2MLP(nn.Module):
@@ -85,6 +93,8 @@ class DeepseekV2MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -93,6 +103,8 @@
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -101,6 +113,8 @@
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -165,7 +179,11 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
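
The precedence here matters: enable_deepep_moe wins over enable_ep_moe, and FusedMoE remains the default when neither flag is set. The same three-way selection is repeated in load_weights (see the @@ -1100 hunk below), so the two sites must stay in sync.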
@@ -182,18 +200,60 @@ class DeepseekV2MoE(nn.Module):
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLP(
+            # disable tp for shared experts when enable deepep moe
+            if not global_server_args_dict["enable_deepep_moe"]:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                )
+            else:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                    tp_rank=0,
+                    tp_size=1,
+                )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            self.num_experts = config.n_routed_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.n_routed_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
-                intermediate_size=intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
+                params_dtype=config.torch_dtype,
+                async_finish=True,  # TODO
             )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -206,8 +266,60 @@
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        shared_output = None
+        topk_idx = torch.full(
+            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+        )
+        topk_weights = torch.empty(
+            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+        )
+        if forward_mode is not None and not forward_mode.is_idle():
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            if self.n_shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+            )
+        if self.tp_size > 1:
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
+                self.deepep_dispatcher.dispatch(
+                    hidden_states,
+                    topk_idx,
+                    topk_weights,
+                    self.num_experts,
+                    forward_mode,
+                )
+            )
+        final_hidden_states = (
+            self.experts(
+                hidden_states=recv_hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                forward_mode=forward_mode,
+            )
+            * self.routed_scaling_factor
+        )
+        if self.tp_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states, forward_mode
+            )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
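
forward_deepep follows a dispatch → grouped expert computation → combine pattern: dispatch permutes (and, across ranks, ships) each token to the experts it was routed to and returns per-expert segment offsets (seg_indptr) so DeepEPMoE can run one contiguous computation per local expert; combine reverses the exchange and accumulates the top-k expert outputs per token. A single-process toy showing the data layout a dispatcher of this kind produces (the real DeepEPDispatcher is a collective over the tensor-parallel group; all names and shapes below are illustrative):

    import torch

    def toy_dispatch(hidden, topk_idx, num_experts):
        # Group (token, expert) pairs by expert id: returns the permuted tokens,
        # the expert id of each permuted row, and per-expert segment offsets.
        flat_idx = topk_idx.flatten()                # one row per (token, expert)
        order = torch.argsort(flat_idx)
        recv = hidden.repeat_interleave(topk_idx.shape[1], dim=0)[order]
        reorder_ids = flat_idx[order]
        seg_indptr = torch.searchsorted(reorder_ids, torch.arange(num_experts + 1))
        return recv, reorder_ids, seg_indptr, order

    hidden = torch.randn(4, 8)                       # 4 tokens, hidden size 8
    topk_idx = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 3]])
    recv, ids, indptr, order = toy_dispatch(hidden, topk_idx, num_experts=4)
    print(indptr)  # tensor([0, 2, 4, 7, 8]): rows 0-1 go to expert 0, 2-3 to expert 1, ...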
@@ -547,7 +659,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
@@ -555,7 +667,7 @@
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
 
     def forward(
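
Both hunks replace a reduction over a device tensor with one over ForwardBatch's host-side copy of the same lengths. Evaluating extend_prefix_lens.sum() == 0 in a boolean context launches a kernel and blocks on a device-to-host transfer each time the condition runs; summing the Python list never touches the GPU. A minimal illustration (assuming, as the diff suggests, that extend_prefix_lens_cpu is a plain Python list kept alongside the CUDA tensor):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    extend_prefix_lens = torch.zeros(64, dtype=torch.int64, device=device)
    extend_prefix_lens_cpu = [0] * 64   # host-side mirror

    # kernel launch; using the result as a Python bool (as the `and` chain
    # above does) forces a device synchronization when the tensor is on GPU
    no_prefix = extend_prefix_lens.sum() == 0

    # pure Python arithmetic; no kernel launch, no synchronization
    no_prefix_cpu = sum(extend_prefix_lens_cpu) == 0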
@@ -937,47 +1049,68 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
             )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
 
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+                if global_server_args_dict["enable_deepep_moe"] and isinstance(
+                    self.mlp, DeepseekV2MoE
+                ):
+                    if hidden_states.shape[0] != 0:
+                        hidden_states, residual = self.post_attention_layernorm(
+                            hidden_states, residual
+                        )
+                    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+                    return hidden_states, residual
+                else:
+                    if get_attention_tp_rank() == 0:
+                        hidden_states += residual
+                    hidden_states, local_hidden_states = (
+                        forward_batch.gathered_buffer,
+                        hidden_states,
+                    )
+                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                    dp_scatter(residual, hidden_states, forward_batch)
+                    hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
         hidden_states = self.mlp(hidden_states)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
@@ -1020,23 +1153,17 @@ class DeepseekV2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
 
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
 
-        hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
@@ -1075,17 +1202,10 @@
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
 
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
 
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -1100,7 +1220,11 @@
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1174,14 +1298,21 @@ class DeepseekV2ForCausalLM(nn.Module):
1174
1298
  self_attn = self.model.layers[layer_id].self_attn
1175
1299
  if hasattr(self_attn.kv_b_proj, "qweight"):
1176
1300
  # AWQ compatible
1177
- w = ops.awq_dequantize(
1178
- self_attn.kv_b_proj.qweight,
1179
- self_attn.kv_b_proj.scales,
1180
- self_attn.kv_b_proj.qzeros,
1181
- 0,
1182
- 0,
1183
- 0,
1184
- ).T
1301
+ if _is_cuda:
1302
+ w = awq_dequantize(
1303
+ self_attn.kv_b_proj.qweight,
1304
+ self_attn.kv_b_proj.scales,
1305
+ self_attn.kv_b_proj.qzeros,
1306
+ ).T
1307
+ else:
1308
+ w = ops.awq_dequantize(
1309
+ self_attn.kv_b_proj.qweight,
1310
+ self_attn.kv_b_proj.scales,
1311
+ self_attn.kv_b_proj.qzeros,
1312
+ 0,
1313
+ 0,
1314
+ 0,
1315
+ ).T
1185
1316
  else:
1186
1317
  w = self_attn.kv_b_proj.weight
1187
1318
  # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.