sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares the contents of two publicly available package versions as published to one of the supported registries; it is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_janus_pro.py
CHANGED

```diff
@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
             high_x, low_x = x_or_tuple
             high_x = self.high_up_proj(high_x)
             low_x = self.low_up_proj(low_x)
-            x = torch.
+            x = torch.cat([high_x, low_x], dim=-1)
         else:
             x = x_or_tuple

@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def
-
-
-
-
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
         )
-
-
-
-
-
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens

     @torch.no_grad()
     def forward(
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:

-        inputs_embeds =
-
-            forward_batch
-
-
-        )
-
-        image_inputs = forward_batch.image_inputs[0]
-
-        images_seq_mask = self.prepare_images_seq_mask(
-            input_ids=input_ids, image_inputs=image_inputs
-        )
-
-        if images_seq_mask is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            inputs_embeds = self.prepare_inputs_embeds(
-                input_ids=input_ids,
-                pixel_values=image_inputs.pixel_values,
-                images_seq_mask=images_seq_mask,
-                images_emb_mask=image_inputs.images_emb_mask,
-            )
-            input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )

         return self.language_model(
-            input_ids=
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )

-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))

-    def pad_input_ids(self, input_ids: List[int], image_inputs:
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         im_start_id = image_inputs.im_start_id
         im_end_id = image_inputs.im_end_id
         media_token_pairs = [(im_start_id, im_end_id)]
```
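The removed `prepare_inputs_embeds` above spelled out the text/image embedding merge inline; the new `general_mm_embed_routine` in `sglang.srt.managers.mm_utils` centralizes that work. As a reading aid, here is a minimal single-batch sketch of the merge pattern (the helper name and signature are illustrative, not the sglang API):

```python
# Hypothetical sketch of the embedding-merge pattern that the removed prepare_inputs_embeds
# implemented and general_mm_embed_routine now centralizes: embed the text tokens, then
# overwrite the image-placeholder positions with the projected vision features.
import torch
import torch.nn as nn


def merge_multimodal_embeddings(
    input_ids: torch.LongTensor,    # [b, T]; image slots may hold out-of-vocab ids
    embed_tokens: nn.Embedding,     # token embedding table
    image_embeds: torch.Tensor,     # [b, n_image_tokens, D] projected vision features
    image_mask: torch.BoolTensor,   # [b, T]; True where an image token sits
) -> torch.Tensor:
    # Clamp placeholder ids into the vocab range so the lookup is valid, mirroring the
    # removed "ignore the image embeddings" step (input_ids[input_ids < 0] = 0).
    inputs_embeds = embed_tokens(input_ids.clamp(min=0))          # [b, T, D]
    # Replace the masked positions with the vision features, one row per True entry.
    inputs_embeds[image_mask] = image_embeds.reshape(-1, image_embeds.shape[-1])
    return inputs_embeds
```

The number of `True` entries in `image_mask` must equal the number of rows in the flattened `image_embeds`, which is what the deleted docstring's `assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)` guarded.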
sglang/srt/models/deepseek_nextn.py
CHANGED

```diff
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops

 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
+else:
+    from vllm import _custom_ops as ops


 class DeepseekModelNextN(nn.Module):
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-
-
-
-
-
-
-
-
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
```
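The same CUDA-or-fallback dequantization appears again in deepseek_v2.py below. A condensed sketch of the pattern, with `dequantize_kv_b_proj` as a hypothetical helper name rather than anything in the package:

```python
# Sketch of the kernel-selection pattern these hunks introduce: prefer sgl_kernel's CUDA op
# and fall back to vllm's _custom_ops elsewhere. The vllm variant of awq_dequantize takes
# three extra positional arguments, passed as zeros in the diff.
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import awq_dequantize
else:
    from vllm import _custom_ops as ops


def dequantize_kv_b_proj(kv_b_proj):
    """Illustrative helper: recover a dense (transposed) weight from an AWQ kv_b_proj."""
    if _is_cuda:
        return awq_dequantize(
            kv_b_proj.qweight, kv_b_proj.scales, kv_b_proj.qzeros
        ).T
    return ops.awq_dequantize(
        kv_b_proj.qweight, kv_b_proj.scales, kv_b_proj.qzeros, 0, 0, 0
    ).T
```

This mirrors the branches added to both `DeepseekV3ForCausalLMNextN` and `DeepseekV2ForCausalLM`.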
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -23,10 +23,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops

 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
@@ -34,7 +34,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -48,8 +48,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -65,15 +67,21 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip

 _is_hip = is_hip()
+_is_cuda = is_cuda()

-if
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
+else:
+    from vllm import _custom_ops as ops
+
+expert_distribution_recorder = ExpertDistributionRecorder()


 class DeepseekV2MLP(nn.Module):
@@ -85,6 +93,8 @@ class DeepseekV2MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -93,6 +103,8 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -101,6 +113,8 @@ class DeepseekV2MLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -165,7 +179,11 @@ class DeepseekV2MoE(nn.Module):

         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

-        MoEImpl =
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
```
```diff
@@ -182,18 +200,60 @@ class DeepseekV2MoE(nn.Module):

         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-
+            # disable tp for shared experts when enable deepep moe
+            if not global_server_args_dict["enable_deepep_moe"]:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                )
+            else:
+                self.shared_experts = DeepseekV2MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    prefix=add_prefix("shared_experts", prefix),
+                    tp_rank=0,
+                    tp_size=1,
+                )
+
+        if global_server_args_dict["enable_deepep_moe"]:
+            self.num_experts = config.n_routed_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+            self.topk_group = config.topk_group
+            self.num_expert_group = config.n_group
+            self.correction_bias = (
+                self.gate.e_score_correction_bias.data
+                if self.gate.e_score_correction_bias is not None
+                else None
+            )
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.n_routed_experts,
+                num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
-
-
-                quant_config=quant_config,
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
+                params_dtype=config.torch_dtype,
+                async_finish=True, # TODO
             )

-    def forward(
-
-
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -206,8 +266,60 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        shared_output = None
+        topk_idx = torch.full(
+            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+        )
+        topk_weights = torch.empty(
+            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+        )
+        if forward_mode is not None and not forward_mode.is_idle():
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            if self.n_shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+            )
+        if self.tp_size > 1:
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
+                self.deepep_dispatcher.dispatch(
+                    hidden_states,
+                    topk_idx,
+                    topk_weights,
+                    self.num_experts,
+                    forward_mode,
+                )
+            )
+        final_hidden_states = (
+            self.experts(
+                hidden_states=recv_hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                forward_mode=forward_mode,
+            )
+            * self.routed_scaling_factor
+        )
+        if self.tp_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states, forward_mode
+            )
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output

-        return final_hidden_states
+        return final_hidden_states


 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
```
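The new `forward_deepep` follows a dispatch, per-expert compute, combine sequence, adding the shared-expert output back at the end. A toy single-process sketch of that control flow (plain indexing stands in for `DeepEPDispatcher`'s cross-rank all-to-all; all names here are illustrative):

```python
# Toy, single-process illustration of the dispatch -> per-expert compute -> combine flow
# that forward_deepep follows; DeepEPDispatcher performs the same steps across ranks.
import torch
import torch.nn as nn


def moe_dispatch_compute_combine(hidden, topk_idx, topk_weights, experts):
    # hidden: [T, D]; topk_idx, topk_weights: [T, K]; experts: one module per expert id.
    out = torch.zeros_like(hidden)
    for expert_id, expert in enumerate(experts):
        # "dispatch": pick out the tokens routed to this expert
        token_pos, slot = (topk_idx == expert_id).nonzero(as_tuple=True)
        if token_pos.numel() == 0:
            continue
        # "compute" on the routed tokens, then "combine": weighted add back in token order
        contrib = topk_weights[token_pos, slot].unsqueeze(-1) * expert(hidden[token_pos])
        out.index_add_(0, token_pos, contrib)
    return out


# Example: 5 tokens, hidden size 8, 4 experts, top-2 routing (as select_experts would produce).
experts = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
hidden = torch.randn(5, 8)
router_probs = torch.softmax(torch.randn(5, 4), dim=-1)
topk_weights, topk_idx = torch.topk(router_probs, k=2, dim=-1)
combined = moe_dispatch_compute_combine(hidden, topk_idx, topk_weights, experts)
```

In the real path, dispatch also moves tokens across ranks, which is why `self.experts` (`DeepEPMoE`) consumes `recv_hidden_states` together with `reorder_topk_ids` and `seg_indptr`, and the results go through `combine` before the shared-expert output is added.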
```diff
@@ -547,7 +659,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
@@ -555,7 +667,7 @@
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.
+                and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )

     def forward(
@@ -937,47 +1049,68 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)

-
-
-
-
-
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
             )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
-
-
-
-
-
-
-
+                if global_server_args_dict["enable_deepep_moe"] and isinstance(
+                    self.mlp, DeepseekV2MoE
+                ):
+                    if hidden_states.shape[0] != 0:
+                        hidden_states, residual = self.post_attention_layernorm(
+                            hidden_states, residual
+                        )
+                    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+                    return hidden_states, residual
+                else:
+                    if get_attention_tp_rank() == 0:
+                        hidden_states += residual
+                    hidden_states, local_hidden_states = (
+                        forward_batch.gathered_buffer,
+                        hidden_states,
+                    )
+                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                    dp_scatter(residual, hidden_states, forward_batch)
+                    hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
         hidden_states = self.mlp(hidden_states)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual


```
```diff
@@ -1020,23 +1153,17 @@ class DeepseekV2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:

-
-
-
-
-                (forward_batch.gathered_buffer.shape[0],),
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            ),
-            input_ids,
-        )
-        dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds

-        hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
@@ -1075,17 +1202,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)

-
-        # important: forward batch.gathered_buffer is used both after scatter and after gather.
-        # be careful about this!
-        hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-            hidden_states,
-        )
-        dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -1100,7 +1220,11 @@

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl =
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1174,14 +1298,21 @@
             self_attn = self.model.layers[layer_id].self_attn
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-
-
-
-
-
-
-
-
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
+                else:
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
             else:
                 w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
```