sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -173,6 +173,7 @@ class ModelRunner:
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
+                "mm_attention_backend": server_args.mm_attention_backend,
             }
         )

@@ -278,9 +279,10 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+                )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -290,9 +292,10 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
+                    if self.should_log:
+                        logger.info(
+                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                        )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -311,9 +314,10 @@ class ModelRunner:
                 server_args.attention_backend = "triton"

         if server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
+            if self.should_log:
+                logger.info(
+                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+                )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -324,23 +328,26 @@ class ModelRunner:

         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-            logger.info(
-                "Automatically turn off --chunked-prefill-size for multimodal model."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                    f"because this is a multimodal model."
+                )
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size for multimodal model."
+                )
             server_args.chunked_prefill_size = -1

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            if self.should_log:
+                logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            if self.should_log:
+                logger.info("Chunked prefix cache is turned on.")

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -361,6 +368,8 @@ class ModelRunner:
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
+        elif self.device == "npu":
+            backend = "hccl"

         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
@@ -431,9 +440,10 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                logger.info(
-                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                )
+                if self.should_log:
+                    logger.info(
+                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                    )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -469,10 +479,11 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                logger.info(
-                    "Loaded KV cache scaling factors from %s",
-                    self.server_args.quantization_param_path,
-                )
+                if self.should_log:
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -547,12 +558,7 @@ class ModelRunner:
             return iter

         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model

         with set_default_torch_dtype(self.model_config.dtype):
@@ -1019,7 +1025,8 @@ class ModelRunner:
         )

     def apply_torch_tp(self):
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        if self.should_log:
+            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
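The recurring edit across these ModelRunner hunks is the same change: informational startup logs are wrapped in an `if self.should_log:` guard so that only one worker prints them when many tensor-parallel ranks start up. Below is a minimal sketch of the pattern, assuming `should_log` is derived from the rank; the class and constructor are illustrative, not the sglang implementation.

```python
import logging

logger = logging.getLogger(__name__)


class RankAwareRunner:
    """Illustrative only: emit info logs from a single designated rank."""

    def __init__(self, tp_rank: int):
        # Assumption: rank 0 is the logging rank; sglang's actual criterion may differ.
        self.should_log = tp_rank == 0

    def apply_torch_tp(self, tp_size: int) -> None:
        if self.should_log:
            logger.info(f"Enabling torch tensor parallelism on {tp_size} devices.")
        # ... the real method then builds the device mesh and applies tensor parallelism ...
```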
@@ -1138,7 +1145,9 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
             )
-
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
         self._preprocess_logits(logits_output, forward_batch.sampling_info)

         # Sample the next tokens
@@ -1149,6 +1158,8 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids

     @property
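The two `thinking_budgets` hooks bracket token sampling: `apply_thinking_budgets` edits the next-token logits before `_preprocess_logits`, and `update_thinking_budgets` is fed the sampled token ids afterwards. Their internals are not part of this diff; the sketch below only shows one plausible shape for such a budget (stop "thinking" once a per-request token allowance is spent), and every name except the two hook methods is hypothetical.

```python
import torch


class ThinkingBudgetSketch:
    """Hypothetical sketch, not the sglang implementation."""

    def __init__(self, thinking_budgets: torch.Tensor, think_end_token_id: int):
        # thinking_budgets[i] = remaining "thinking" tokens allowed for request i.
        self.thinking_budgets = thinking_budgets
        self.think_end_token_id = think_end_token_id

    def apply_thinking_budgets(self, next_token_logits: torch.Tensor) -> None:
        # Force requests whose budget is exhausted to emit the end-of-thinking token.
        exhausted = self.thinking_budgets == 0
        if exhausted.any():
            next_token_logits[exhausted] = float("-inf")
            next_token_logits[exhausted, self.think_end_token_id] = 0.0

    def update_thinking_budgets(self, next_token_ids: torch.Tensor) -> None:
        # Spend one unit of budget for every request that kept thinking this step.
        still_thinking = (self.thinking_budgets > 0) & (
            next_token_ids != self.think_end_token_id
        )
        self.thinking_budgets[still_thinking] -= 1
```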
sglang/srt/model_loader/loader.py
CHANGED
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )

-            model.load_weights(self._get_all_weights(model_config, model))
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )

-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
         return model.eval()

+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+

 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a
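The refactor above moves the post-load quantization pass into a reusable `DefaultModelLoader.load_weights_and_postprocess` static method, which the `ModelRunner` weight-update path (earlier hunk) now calls instead of duplicating the loop. A hedged usage sketch; the surrounding function and its arguments are placeholders.

```python
import torch

from sglang.srt.model_loader.loader import DefaultModelLoader


def reload_weights(model: torch.nn.Module, weight_iter, target_device: torch.device):
    # Load weights into an already-constructed model, then let each module's
    # quant_method repack/quantize on the target device (same call the diff adds).
    DefaultModelLoader.load_weights_and_postprocess(model, weight_iter, target_device)
    return model.eval()
```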
sglang/srt/models/clip.py
CHANGED
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm1 = norm_layer(config.hidden_size)
         self.layer_norm2 = norm_layer(config.hidden_size)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
         self.self_attn = VisionAttention(
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             projection_size=config.hidden_size,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
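In `CLIPEncoderLayer`, the boolean `use_context_forward` flag is replaced by an explicit `qkv_backend` string passed to `VisionAttention`. The mapping the hunk encodes, restated as a small helper for readability (the helper itself is illustrative, not part of the package):

```python
def select_clip_attention(attn_implementation: str) -> tuple[str, bool]:
    """Return (qkv_backend, softmax_in_single_precision) as chosen in the hunk above."""
    if attn_implementation == "sdpa":
        return "sdpa", False
    if attn_implementation == "flash_attention_2":
        return "triton_attn", False
    if attn_implementation == "eager":
        return "sdpa", True
    raise ValueError(f"Unsupported attn_implementation: {attn_implementation}")
```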
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix

 logger = logging.getLogger(__name__)

sglang/srt/models/deepseek_v2.py
CHANGED
@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )

 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_idx,
                 topk_weights,
                 reorder_topk_ids,
+                num_recv_tokens_per_expert,
                 seg_indptr,
                 masked_m,
                 expected_m,
@@ -367,10 +370,13 @@ class DeepseekV2MoE(nn.Module):
             )
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
                 masked_m=masked_m,
                 expected_m=expected_m,
+                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
                 forward_mode=forward_mode,
             )
             if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +550,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )

+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +715,36 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
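When a CUDA graph is being captured, the MLA path now runs the `q` layernorm on the current stream and the `kv_a` layernorm on a secondary stream, synchronizing before the results are used. The same wait/launch/wait pattern, reduced to generic tensors (the norm modules here are placeholders):

```python
import torch


def overlapped_norms(q, k_nope, q_norm, k_norm, alt_stream: torch.cuda.Stream):
    # Run one norm on the current stream and the other on alt_stream in parallel,
    # then make the current stream wait so both results are ready for later ops.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)
    q = q_norm(q)
    with torch.cuda.stream(alt_stream):
        k_nope = k_norm(k_nope)
    current_stream.wait_stream(alt_stream)
    return q, k_nope
```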
@@ -750,14 +775,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +789,8 @@ class DeepseekV2AttentionMLA(nn.Module):

         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1)
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
@@ -1104,6 +1124,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1154,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )

         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1406,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
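The decoder layers do not each create their own stream: `DeepseekV2Model` allocates a single `torch.cuda.Stream` and threads it through every `DeepseekV2DecoderLayer` down to the attention module. A minimal sketch of that wiring with stand-in layer classes (requires a CUDA device):

```python
import torch
from torch import nn


class LayerSketch(nn.Module):
    def __init__(self, hidden_size: int, alt_stream=None):
        super().__init__()
        self.alt_stream = alt_stream  # shared secondary stream, may be None
        self.proj = nn.Linear(hidden_size, hidden_size)


class ModelSketch(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        self.alt_stream = torch.cuda.Stream()  # one stream shared by all layers
        self.layers = nn.ModuleList(
            [LayerSketch(hidden_size, alt_stream=self.alt_stream) for _ in range(num_layers)]
        )
```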
@@ -1467,8 +1491,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self.n_share_experts_fusion = 0
             global_server_args_dict["n_share_experts_fusion"] = 0
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            log_info_on_rank0(
+                logger,
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
             )
         else:
             assert (
@@ -1483,8 +1508,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self.n_share_experts_fusion = self.tp_size
             global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            logger.info(
-                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+            log_info_on_rank0(
+                logger,
+                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
             )

     def get_input_embeddings(self) -> nn.Embedding:
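Both branches now log through `log_info_on_rank0(logger, msg)` imported from `sglang.srt.utils` rather than calling `logger.info` on every rank. The helper's implementation is not shown in this diff; a plausible sketch under the assumption that it checks the process rank:

```python
import logging

import torch.distributed as dist


def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    # Assumed behaviour: only the (global) rank-0 process emits the message.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)
```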
@@ -1564,13 +1590,22 @@ class DeepseekV2ForCausalLM(nn.Module):

             if (
                 _is_cuda
-                and _ENABLE_JIT_DEEPGEMM
                 and weight_block_size[0] == 128
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                block_scale = weight_scale
-                use_deep_gemm_bmm = True
+                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                    "SGL_USE_DEEPGEMM_BMM", "false"
+                ):
+                    block_scale = weight_scale
+                    use_deep_gemm_bmm = True
+                else:
+                    w = block_quant_dequant(
+                        weight,
+                        weight_scale,
+                        weight_block_size,
+                        model_dtype,
+                    )
             else:
                 w, scale = block_quant_to_tensor_quant(
                     weight, weight_scale, weight_block_size
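The weight-processing change keeps the DeepGEMM grouped-BMM path only when JIT DeepGEMM is enabled and the `SGL_USE_DEEPGEMM_BMM` environment variable is set; otherwise the block-quantized FP8 weight is dequantized back to the model dtype with `block_quant_dequant`. A hedged sketch of the boolean env-var gate; the helper mirrors the `get_bool_env_var` name used in the hunk, but its internals are assumed:

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumption: "1", "true", and "yes" (case-insensitive) count as enabled.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


use_deep_gemm_bmm = get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
```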
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to(
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
sglang/srt/models/internlm2.py
CHANGED