sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +67 -13
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +36 -12
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +35 -4
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +38 -12
- sglang/srt/managers/scheduler.py +41 -28
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +19 -25
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +50 -11
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +31 -24
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +5 -1
- sglang/test/runners.py +6 -13
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +74 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm.py

@@ -25,7 +25,7 @@ if is_cuda():
 
     sm_version = get_device_sm()
     if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
             _ENABLE_JIT_DEEPGEMM = True
 
 logger = logging.getLogger(__name__)
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
     "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
 )
-_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
+_DO_COMPILE_ALL = True
+_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
 _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
-…
+_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
 
 # Force redirect deep_gemm cache_dir
 os.environ["DG_CACHE_DIR"] = os.getenv(
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
-    global _DO_COMPILE
+    global _DO_COMPILE_ALL
+    global _IS_FIRST_RANK_ON_NODE
 
     # Generate m_max
     m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     m_max = min(1024 * 128, m_max)
     _BUILTIN_M_LIST = list(range(1, m_max + 1))
 
-    # Check if is the first rank on node
-    _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
+    _IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
+
+    # Check if is the first rank on node.
+    # Default each rank will try compile all Ms to
+    # load all symbols at the launch stages.
+    # Avoid loading symbols at the serving stages.
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
 
 
 class DeepGemmKernelType(IntEnum):
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
 def _compile_warning_1():
-    if not …:
+    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Complie session. "
             "And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     query_key = (kernel_type, n, k, num_groups)
     if (
         _ENABLE_JIT_DEEPGEMM_PRECOMPILE
-        and _DO_COMPILE
+        and _DO_COMPILE_ALL
         and _INITIALIZATION_DICT.get(query_key) is None
     ):
         _INITIALIZATION_DICT[query_key] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
-            f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not …}"
+            f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )
 
         # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
 
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
-    if …:
+    if _IN_PRECOMPILE_STAGE:
         yield
         return
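The deep_gemm.py hunks above are all driven by a handful of `SGL_*` boolean environment variables. Below is a minimal sketch of the gating they implement, using a simplified stand-in for `get_bool_env_var` (the real helper lives in `sglang.srt.utils`; only the flag names are taken from the diff):

```python
import os


def _env_flag(name: str, default: str) -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true")


# Flag names from the hunks above.
enable_jit_deepgemm = _env_flag("SGL_ENABLE_JIT_DEEPGEMM", "true")  # default flipped to "true"
in_precompile_stage = _env_flag("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
is_first_rank_on_node = _env_flag("SGL_IS_FIRST_RANK_ON_NODE", "true")

# A rank compiles the full M list only if it is the first rank on its node,
# or if it is not running inside the dedicated precompile stage.
do_compile_all = is_first_rank_on_node or not in_precompile_stage
```

With the new `default="true"`, JIT DeepGEMM is enabled out of the box on SM90; the precompile-stage flag is presumably what the `sglang.compile_deep_gemm` entry point (also changed in this release, see the file list) sets for its warm-up pass.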
sglang/srt/layers/quantization/fp8.py

@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
 if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
             tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:
 
         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            self.process_weights_hip_int4(layer)
            return
@@ -675,7 +675,7 @@ class Fp8MoEMethod:
            )
            layer.w2_input_scale = None
 
-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                # Pre-shuffle weights
                layer.w13_weight.data = shuffle_weight(
                    layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
            return
 
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )
 
-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )
 
         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                ),
            )
 
-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
            assert not no_combine, f"{no_combine=} is not supported."
            if self.block_quant:
-                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                assert (
                    activation == "silu"
-                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                return asm_moe(
                    x,
                    layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
+                    QuantType.per_Token,
                    layer.w13_weight_scale1,
                    layer.w2_weight_scale1,
                    activation=(
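The fp8.py hunks rename the ROCm/aiter toggles to `SGLANG_`-prefixed names and thread aiter's `QuantType.per_Token` through the fused-MoE call sites. A hedged sketch of setting the renamed flags (names from the diff; the values here are illustrative, and the flags must be in place before sglang's modules are imported, since some are read at module level):

```python
import os

# New flag names introduced in these hunks.
os.environ["SGLANG_AITER_MOE"] = "1"    # aiter fused-MoE path (asm_moe / ck_moe_2stages)
os.environ["SGLANG_INT4_WEIGHT"] = "0"  # INT4-FP8 MoE weight path
os.environ["SGLANG_MOE_PADDING"] = "0"  # pad MoE weights to reduce memory-channel contention

import sglang  # noqa: E402  -- import only after the flags are set
```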
sglang/srt/layers/quantization/fp8_utils.py

@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
sglang/srt/layers/quantization/int8_kernel.py

@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
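The new `sglang_per_token_group_quant_int8` wrapper (added in full above) allocates the outputs and defers to the `sgl_per_token_group_quant_int8` CUDA kernel. A small usage sketch under the shapes the function asserts (a CUDA build with `sgl_kernel` is assumed; the fp16 input dtype is an assumption, not something the diff pins down):

```python
import torch

from sglang.srt.layers.quantization.int8_kernel import (
    sglang_per_token_group_quant_int8,
)

# x: (..., K) with K divisible by group_size; must be contiguous.
x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)

assert x_q.shape == x.shape and x_q.dtype == torch.int8
# One float32 scale per 128-wide group: shape (4, 256 // 128).
assert x_s.shape == (4, 2) and x_s.dtype == torch.float32
```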
sglang/srt/layers/radix_attention.py

@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
-            q, k, v, self, forward_batch, save_kv_cache
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
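In the RadixAttention hunk, the new `**kwargs` path changes only the view applied to `k`: when a caller passes `k_rope` separately, `k` is taken to carry just the `v_head_dim`-wide part. A shape-only sketch of that branch (the sizes are hypothetical, and reading `k_rope` as the separately-passed rotary slice for MLA-style backends is an inference from the diff, not documented here):

```python
import torch

# Hypothetical sizes: qk_head_dim = nope + rope parts, v_head_dim = nope part only.
tp_k_head_num, qk_head_dim, v_head_dim, tokens = 1, 576, 512, 8

k = torch.randn(tokens * tp_k_head_num * v_head_dim)
kwargs = {"k_rope": torch.randn(tokens, 1, qk_head_dim - v_head_dim)}

if "k_rope" not in kwargs:
    k = k.view(-1, tp_k_head_num, qk_head_dim)  # fused layout: rope part packed into k
else:
    k = k.view(-1, tp_k_head_num, v_head_dim)   # rope part travels separately via kwargs
print(k.shape)  # torch.Size([8, 1, 512])
```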
sglang/srt/layers/rotary_embedding.py

@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
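The `DeepseekScalingRotaryEmbedding` hunks capture the incoming dtype and cast the outputs back, so the fp32 cos/sin math no longer changes the dtype of the tensors returned by `forward_native`. The pattern in isolation (the actual rotation is elided; this is only an illustration of the dtype round-trip):

```python
import torch


def rope_native_dtype_roundtrip(query: torch.Tensor, key: torch.Tensor):
    # Remember the input dtype, do the (possibly fp32) rotary math, cast back.
    dtype = query.dtype
    query, key = query.float(), key.float()  # stand-in for the rotation math
    return query.to(dtype), key.to(dtype)


q = torch.randn(2, 16, dtype=torch.bfloat16)
k = torch.randn(2, 16, dtype=torch.bfloat16)
q_out, k_out = rope_native_dtype_roundtrip(q, k)
assert q_out.dtype == torch.bfloat16 and k_out.dtype == torch.bfloat16
```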
sglang/srt/layers/rotary_embedding.py

@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-        …
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-        …
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
             )
-        …
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-        …
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas
 
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
                 )
-            )
-            for _ in range(3)
-        ]
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
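With this hunk, `MRotaryEmbedding` switches from list-returning `get_input_positions` to the tensor-returning `get_rope_index` copied from transformers. A minimal sketch of calling it for a text-only batch, where the `else` branch applies (the token-id values are arbitrary placeholders, and this assumes the 0.4.6.post1 wheel is installed):

```python
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

input_ids = torch.arange(8).unsqueeze(0)  # (batch=1, seq_len=8), no vision tokens
position_ids, deltas = MRotaryEmbedding.get_rope_index(
    spatial_merge_size=2,
    image_token_id=151655,  # placeholder ids; unused in the text-only branch
    video_token_id=151656,
    vision_start_token_id=151652,
    model_type="qwen2_vl",
    input_ids=input_ids,
)
# position_ids has a leading dim of 3 (t/h/w); for text-only input each row is
# simply 0..seq_len-1, and the returned mrope delta is 0.
```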