sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -40,7 +43,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
 from sglang.srt.utils import add_prefix, is_hip
 
-
+_is_hip = is_hip()
 
 
 class DeepseekModelNextN(nn.Module):
@@ -277,7 +280,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             weight_block_size = self.quant_config.weight_block_size
             if weight_block_size is not None:
                 assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if
+                if _is_hip:
                     weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                         weight=w,
                         weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                     weight, weight_scale, weight_block_size
                 )
                 self_attn.w_scale = scale
+        if w.dtype == torch.int8:
+            if hasattr(self.quant_config, "weight_block_size"):
+                # block-wise int8 need it
+                weight_block_size = self.quant_config.weight_block_size
+                if weight_block_size is not None:
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    w = int8_block_dequant(
+                        weight, weight_scale, weight_block_size
+                    ).to(torch.bfloat16)
+            else:
+                # channel-wise int8 need it
+                assert hasattr(self_attn.kv_b_proj, "weight_scale")
+                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                    torch.bfloat16
+                )
         w_kc, w_vc = w.unflatten(
             0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
         ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -301,7 +321,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             and self_attn.w_scale is None
         ):
             self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if
+            if _is_hip:
                 self_attn.w_scale *= 2.0
 
 
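The kv_b_proj post-processing above dequantizes int8 checkpoints back to bfloat16 before the weight is split into w_kc/w_vc for MLA weight absorption. Below is a minimal, illustrative sketch of what a block-wise dequant such as int8_block_dequant(weight, weight_scale, weight_block_size) computes, followed by the same unflatten/split step as in the diff. The scale layout (one float scale per [block_n, block_k] tile) is an assumption for illustration only; it is not the actual kernel in sglang.srt.layers.quantization.int8_utils.

import torch

def int8_block_dequant_ref(weight: torch.Tensor,
                           weight_scale: torch.Tensor,
                           block_size: list) -> torch.Tensor:
    """Reference block-wise int8 dequant (illustrative, not the sglang kernel).

    weight:       [N, K] int8
    weight_scale: [ceil(N/bn), ceil(K/bk)] float, one scale per tile (assumed layout)
    block_size:   [bn, bk], e.g. [128, 128]
    """
    bn, bk = block_size
    n, k = weight.shape
    # Expand each per-tile scale over its [bn, bk] tile, then crop to [N, K].
    scale_full = (
        weight_scale.repeat_interleave(bn, dim=0)
        .repeat_interleave(bk, dim=1)[:n, :k]
    )
    return weight.to(torch.float32) * scale_full

if __name__ == "__main__":
    # Toy shapes mirroring the diff: dequantize, then split into w_kc / w_vc.
    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 8, 16
    w_int8 = torch.randint(
        -128, 127,
        (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank),
        dtype=torch.int8,
    )
    scales = torch.rand((w_int8.shape[0] + 127) // 128, (kv_lora_rank + 127) // 128) * 0.01
    w = int8_block_dequant_ref(w_int8, scales, [128, 128]).to(torch.bfloat16)
    w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
        [qk_nope_head_dim, v_head_dim], dim=1
    )
    print(w_kc.shape, w_vc.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])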
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -65,7 +70,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
-
+_is_hip = is_hip()
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        reduce_results: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -241,10 +247,14 @@
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-
-
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -272,6 +282,8 @@
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("q_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +308,9 @@
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -330,6 +345,12 @@
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
-
-
+        reduce_results: bool = True,
+        layer_id: int = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-
-
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-
-
-
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ReplicatedLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
-            )
-            # O projection.
-            self.o_proj = ReplicatedLinear(
-                self.num_heads * self.v_head_dim,
+        # For tensor parallel attention
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
                 self.hidden_size,
+                self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("
+                prefix=add_prefix("q_a_proj", prefix),
             )
-
-
-
-                self.
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ColumnParallelLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ColumnParallelLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("
+                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
-
-            self.
-                self.num_heads * self.v_head_dim,
+        else:
+            self.q_proj = ColumnParallelLinear(
                 self.hidden_size,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("
+                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
@@ -542,36 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
+            # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            return (
+                not self.flashinfer_mla_disable_ragged
+                and forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            return (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
 
-
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Do not absorb when enabling ragged prefill
-                return (
-                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
-                    and forward_batch.forward_mode.is_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-            else:
-                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-                return (
-                    forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-
-        if no_absorb():
+        if self.no_absorb(forward_batch):
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            if
+            if _is_hip:
                 if (
-
+                    self.rocm_fused_decode_mla
                     and forward_batch.forward_mode.is_decode()
                 ):
                     return self.forward_absorb_fused_mla_rope(
@@ -843,34 +847,6 @@
         return output
 
 
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    if world_size == 1:
-        return input_tensor
-
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -886,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention =
-
-
-
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -911,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
@@ -931,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
@@ -963,33 +937,47 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         # Self Attention
-
-
-
-
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
             else:
-                hidden_states
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-            )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         # Fully Connected
-
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
-
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
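The removed all_gather helper earlier in this file shows the bookkeeping the new dp_gather/dp_scatter calls are built around: each attention-DP rank holds a different number of tokens, so per-rank tensors are packed into a shared gathered_buffer before the shared MLP/all-reduce and each rank's slice is taken back out afterwards. The single-process sketch below simulates that pad/pack/slice index math with plain tensors, modeled on the removed helper; it is only an illustration, not the sglang communication code, and the names used here are hypothetical.

import torch

def simulate_dp_gather(local_tensors):
    """Single-process simulation of the padded gather in the removed all_gather helper.

    Each "rank" contributes a [num_tokens_i, hidden] tensor; tensors are padded to
    the largest token count, stacked into one buffer, and the valid rows are
    concatenated back together. Returns the packed tensor plus the (start, end)
    slice each rank would read back out for its own tokens.
    """
    all_lens = [t.shape[0] for t in local_tensors]
    max_len = max(all_lens)
    padded = [
        torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0]))
        for t in local_tensors
    ]
    gathered_buffer = torch.cat(padded, dim=0)  # what all_gather_into_tensor would fill
    packed = torch.cat(
        [gathered_buffer[i * max_len : i * max_len + all_lens[i]] for i in range(len(local_tensors))]
    )
    bounds = []
    for rank in range(len(local_tensors)):
        start = sum(all_lens[:rank])
        bounds.append((start, start + all_lens[rank]))
    return packed, bounds

if __name__ == "__main__":
    hidden = 4
    ranks = [torch.randn(n, hidden) for n in (3, 1, 2)]  # uneven token counts per DP rank
    packed, bounds = simulate_dp_gather(ranks)
    # "Scatter" recovers each rank's rows from the packed tensor.
    for rank, (s, e) in enumerate(bounds):
        assert torch.equal(packed[s:e], ranks[rank])
    print(packed.shape, bounds)  # torch.Size([6, 4]) [(0, 3), (3, 4), (4, 6)]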
@@ -1025,12 +1013,27 @@
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -1057,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-
-
-
-
-
-
-
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
|
|
1082
1077
|
forward_batch: ForwardBatch,
|
1083
1078
|
) -> torch.Tensor:
|
1084
1079
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
1080
|
+
|
1081
|
+
if self.dp_size != 1:
|
1082
|
+
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
1083
|
+
# be careful about this!
|
1084
|
+
hidden_states, global_hidden_states = (
|
1085
|
+
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
1086
|
+
hidden_states,
|
1087
|
+
)
|
1088
|
+
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
1089
|
+
|
1085
1090
|
return self.logits_processor(
|
1086
1091
|
input_ids, hidden_states, self.lm_head, forward_batch
|
1087
1092
|
)
|
@@ -1188,7 +1193,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1202,18 +1207,22 @@ class DeepseekV2ForCausalLM(nn.Module):
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
-                    if
-                        hasattr(self.quant_config, "weight_block_size")
-
-
-
-
-
-
-
-
-
-
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -1224,7 +1233,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if
+                if _is_hip:
                    self_attn.w_scale *= 2.0
 
     def get_embed_and_head(self):