sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -320,17 +320,30 @@ class ForwardBatch:
 
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
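The hunk above scales each DP rank's token count by the number of speculative draft tokens before sizing `gathered_buffer`. Below is a minimal sketch of that arithmetic; the list, draft count, and hidden size are made-up values, not taken from the package.

```python
# Illustration of the buffer-sizing arithmetic above; all values are made up.
import torch

global_num_tokens = [8, 6, 7, 5]   # tokens per DP rank in the batch
spec_num_draft_tokens = 4          # draft tokens per position when speculation is on
hidden_size = 16                   # toy hidden size

# Each rank now contributes tokens * draft-tokens rows to the gathered buffer.
scaled = [x * spec_num_draft_tokens for x in global_num_tokens]
gathered_buffer = torch.zeros((sum(scaled), hidden_size), dtype=torch.bfloat16)

print(scaled)                 # [32, 24, 28, 20]
print(gathered_buffer.shape)  # torch.Size([104, 16])
```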
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -162,6 +168,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
 
@@ -218,6 +226,7 @@
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -272,6 +281,10 @@
         self.apply_torch_tp()
 
         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()
 
@@ -299,7 +312,7 @@
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not cpu_has_amx_support()
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -543,7 +556,7 @@
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
@@ -761,6 +774,9 @@
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
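The new `custom_weight_loader` branch resolves a dotted path with `dynamic_import` and calls the result on the model's named tensors. The diff does not show `dynamic_import` itself; the sketch below is one plausible implementation of such a helper with `importlib`, not the actual sglang code, and `dynamic_import_sketch` plus the example loader path are hypothetical names.

```python
# A plausible dotted-path import helper in the spirit of sglang.srt.utils.dynamic_import;
# the real implementation may differ.
import importlib
from typing import Any


def dynamic_import_sketch(dotted_path: str) -> Any:
    """Resolve "package.module.attr" to the object it names."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Usage: a loader registered as "my_pkg.loaders.load_custom_weights" (hypothetical)
# would be resolved at runtime and invoked as custom_loader(model, named_tensors).
json_dumps = dynamic_import_sketch("json.dumps")
assert json_dumps({"ok": True}) == '{"ok": true}'
```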
@@ -787,7 +803,6 @@
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +811,7 @@
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
sglang/srt/model_loader/loader.py
CHANGED
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files, disable_mmap=weight_loader_disable_mmap
+            )
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
@@ -206,7 +207,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
@@ -418,6 +422,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
@@ -439,7 +444,11 @@
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
 
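With `disable_mmap=True` the iterator reads each checkpoint file fully into memory and deserializes it with `safetensors.torch.load`, instead of memory-mapping it through `load_file`. A self-contained sketch of the two paths; the file name is a placeholder.

```python
# Sketch of the two safetensors loading paths; "model.safetensors" is a placeholder.
import safetensors.torch


def load_one_file(st_file: str, disable_mmap: bool = False):
    if disable_mmap:
        # Read the whole file into RAM, then deserialize from the in-memory bytes.
        with open(st_file, "rb") as f:
            return safetensors.torch.load(f.read())
    # Default path: memory-map the file and return CPU tensors backed by the mapping.
    return safetensors.torch.load_file(st_file, device="cpu")


# for name, tensor in load_one_file("model.safetensors").items():
#     ...
```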
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -45,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -77,6 +82,7 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,15 +96,16 @@ class DeepseekModelNextN(nn.Module):
         else:
             hidden_states = input_embeds
 
-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )
 
         residual = None
         hidden_states, residual = self.decoder(
@@ -106,7 +113,11 @@
         )
 
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states
 
 
@@ -127,23 +138,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-
-
-
-
-
-
-
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -95,8 +95,10 @@ from sglang.srt.utils import (
     LazyValue,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -107,9 +109,13 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -220,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -232,6 +239,7 @@ class DeepseekV2MoE(nn.Module):
         )
         self.config = config
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -269,6 +277,15 @@ class DeepseekV2MoE(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
@@ -332,10 +349,38 @@
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
-            return self.forward_normal(hidden_states)
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(hidden_states)
+            else:
+                return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+        final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
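`forward_normal_dual_stream` runs the shared experts on the current stream while the routed experts run on `alt_stream`, using `wait_stream` on both sides to order the work before the results are summed. Below is a minimal, self-contained sketch of that overlap pattern with stand-in compute functions; it assumes a CUDA device and the lambdas are arbitrary placeholders.

```python
# Sketch of the dual-stream overlap pattern; fn_main/fn_alt are stand-ins.
import torch


def run_overlapped(x, fn_main, fn_alt, alt_stream: torch.cuda.Stream):
    main_stream = torch.cuda.current_stream()
    # The side stream must see all prior work on x before it starts.
    alt_stream.wait_stream(main_stream)
    out_main = fn_main(x)              # issued on the current stream
    with torch.cuda.stream(alt_stream):
        out_alt = fn_alt(x)            # issued concurrently on alt_stream
    # Do not consume the side-stream result until that stream has finished.
    main_stream.wait_stream(alt_stream)
    return out_main + out_alt


if torch.cuda.is_available():
    x = torch.randn(64, 256, device="cuda")
    alt = torch.cuda.Stream()
    y = run_overlapped(x, lambda t: t @ t.T @ t, lambda t: torch.relu(t), alt)
```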
@@ -665,13 +710,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -1040,13 +1086,16 @@
                 masked_m,
                 expected_m,
             )
-            attn_bmm_output =
+            attn_bmm_output = (
+                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+            )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1059,10 +1108,21 @@
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.
-
-
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+
         output, _ = self.o_proj(attn_bmm_output)
 
         return output
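The new `else` branch preallocates the flattened `(tokens, heads * v_head_dim)` buffer and lets `torch.bmm` write directly into a transposed view of it, so a separate transpose-and-flatten pass is not needed afterwards. A toy reproduction of the layout trick with illustrative shapes, assuming a CUDA device to match the real code path:

```python
# Toy check of the bmm-into-a-transposed-view trick; shapes are illustrative.
import torch

if torch.cuda.is_available():
    heads, tokens, kv_dim, v_head_dim = 4, 3, 8, 5
    attn_output = torch.randn(tokens, heads, kv_dim, device="cuda")
    w_vc = torch.randn(heads, kv_dim, v_head_dim, device="cuda")

    # Preallocate the flattened result and write each head's bmm output into it in place.
    out = torch.empty(tokens, heads * v_head_dim, device="cuda")
    torch.bmm(
        attn_output.transpose(0, 1),
        w_vc,
        out=out.view(tokens, heads, v_head_dim).transpose(0, 1),
    )

    reference = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1).flatten(1, 2)
    assert torch.allclose(out, reference)
```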
@@ -1399,7 +1459,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1426,7 +1488,7 @@
 
         self.layer_scatter_modes = LayerScatterModes.init_new(
             layer_id=layer_id,
-            num_layers=config.num_hidden_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
             is_layer_sparse=self.is_layer_sparse,
             is_previous_layer_sparse=is_previous_layer_sparse,
         )
@@ -1437,6 +1499,7 @@
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
+                alt_stream=alt_stream,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1479,6 +1542,7 @@
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -1500,6 +1564,11 @@
             hidden_states, residual, forward_batch
         )
 
+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
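The `.clone()` above gives the EAGLE path its own copy of the hidden states, so later in-place updates to the shared DP-attention buffer cannot silently alter what was captured. A tiny illustration of the aliasing behavior; the buffer here is a stand-in, not sglang's actual buffer.

```python
# Minimal illustration of why cloning matters for a captured tensor.
import torch

buffer = torch.zeros(4)

captured_view = buffer          # aliases the same storage
captured_copy = buffer.clone()  # owns separate storage

buffer.add_(1.0)                # later in-place update of the shared buffer

print(captured_view)  # tensor([1., 1., 1., 1.])  -- changed along with the buffer
print(captured_copy)  # tensor([0., 0., 0., 0.])  -- unaffected
```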
@@ -1607,8 +1676,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -1692,7 +1759,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -1717,12 +1783,12 @@
         disable_reason = None
         if (
             not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (
+            or torch.cuda.get_device_capability("cuda") < (8, 0)
             or self.config.architectures[0] != architecture
             or self.config.n_routed_experts != 256
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >=
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif (
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
@@ -1919,10 +1985,12 @@
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
-            self._weight_requant_ue8m0()
+            self._weight_requant_ue8m0(is_nextn)
 
-    def _weight_requant_ue8m0(self):
+    def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
         moe_layers = list(
@@ -1933,8 +2001,12 @@
             )
         )
 
-        for layer_id in range(self.config.num_hidden_layers):
-            layer = self.model.layers[layer_id]
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
 
             for module in [
                 layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -1946,7 +2018,7 @@
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
-            if layer_id in moe_layers:
+            if layer_id in moe_layers or is_nextn:
                 shared_experts = getattr(layer.mlp, "shared_experts", None)
                 if shared_experts is not None:
                     for module in [
@@ -2022,7 +2094,7 @@
 
         if self.num_fused_shared_experts > 0:
             assert self.num_fused_shared_experts == 1
-            logger
+            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
         params_dict = dict(self.named_parameters())
         weight_names = []
@@ -2128,8 +2200,14 @@
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if self.quant_config is not None and (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
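The concat axis flips because packed quantized checkpoints (AWQ, moe_wna16) typically store projection weights with the output dimension last, whereas unquantized `Linear` weights are `(out_features, in_features)`. A shape-only sketch with illustrative sizes, not DeepSeek's real dimensions:

```python
# Shape-only sketch of why the fuse axis differs for packed quantized weights.
import torch

in_features, q_out, kv_out = 512, 128, 64

# Unquantized Linear weights are (out_features, in_features): fuse along dim 0.
q_a = torch.randn(q_out, in_features)
kv_a = torch.randn(kv_out, in_features)
fused = torch.cat([q_a, kv_a], dim=0)          # -> (192, 512)

# AWQ-style packed int weights keep the output dimension last: fuse along dim 1.
pack_factor = 8
q_a_packed = torch.randint(0, 2**31 - 1, (in_features, q_out // pack_factor), dtype=torch.int32)
kv_a_packed = torch.randint(0, 2**31 - 1, (in_features, kv_out // pack_factor), dtype=torch.int32)
fused_packed = torch.cat([q_a_packed, kv_a_packed], dim=1)  # -> (512, 24)
```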
@@ -2151,12 +2229,16 @@
                     "k_scale" in name or "v_scale" in name
                 ) and name not in params_dict:
                     # modelopt attn kv scale is named differently
-
-
-
-
-
-
+                    for scale in ["k_scale", "v_scale"]:
+                        if scale in name:
+                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                            break
+                    if name not in params_dict:
+                        # modelopt ckpt contains not needed weights for MTP module:
+                        # model.decoder.self_attn.attn_mqa.v_scale and
+                        # model.decoder.self_attn.attn_mqa.k_scale
+                        logger.warning(f"{name} not found in params_dict.")
+                        continue
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader