sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -16,19 +16,17 @@ import functools
 import json
 import logging
 import os
-from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     get_device_core_count,
     get_device_name,
-    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
@@ -43,22 +41,16 @@ else:
     fp8_max = torch.finfo(_fp8_type).max
     fp8_min = -fp8_max
 
-_enable_jit_deepgemm = False
-_enable_jit_deepgemm_bmm = False
 if _is_cuda:
-    import deep_gemm
     from sgl_kernel import (
         sgl_per_tensor_quant_fp8,
         sgl_per_token_group_quant_fp8,
         sgl_per_token_quant_fp8,
     )
 
-
-
-
-        _enable_jit_deepgemm = True
-    if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
-        _enable_jit_deepgemm_bmm = True
+    from sglang.srt.layers.quantization.deep_gemm import (
+        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -71,10 +63,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
-        N, _ = B.shape
-        with _log_jit_build(M, N, K):
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
-@contextmanager
-def _log_jit_build(M: int, N: int, K: int):
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.__getitem__
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            logger.warning(
-                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.__getitem__ = __patched_func
-    yield
-    RuntimeCache.__getitem__ = origin_func
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -804,12 +774,11 @@ w8a8_block_fp8_matmul(
     )
 
     # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and
+    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
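The fp8_kernel.py hunks above drop the module-local `_enable_jit_deepgemm` flag and the `_log_jit_build` monkey-patch in favor of the centralized `_ENABLE_JIT_DEEPGEMM` flag and `gemm_nt_f8f8bf16` wrapper from the new `deep_gemm.py` module. For orientation, here is a rough pure-PyTorch sketch of the block-scaled matmul that `w8a8_block_fp8_matmul` dispatches to DeepGEMM or the Triton kernel; the helper name, the 128x128 block shape, and the dequantize-then-matmul strategy are illustrative assumptions, not code from the wheel.

```python
# Rough pure-PyTorch reference for a block-scaled w8a8 FP8 matmul
# (dequantize-then-matmul). The real kernels never materialize the
# dequantized matrices; this is only to show the scaling layout.
import torch

def w8a8_block_fp8_matmul_reference(
    A_q: torch.Tensor,   # [M, K] quantized activations
    As: torch.Tensor,    # [M, K // block_k] per-token-group scales
    B_q: torch.Tensor,   # [N, K] quantized weights
    Bs: torch.Tensor,    # [N // block_n, K // block_k] per-block scales
    block_n: int = 128,
    block_k: int = 128,
) -> torch.Tensor:
    A = A_q.float() * As.repeat_interleave(block_k, dim=1)
    B = B_q.float() * Bs.repeat_interleave(block_n, dim=0).repeat_interleave(
        block_k, dim=1
    )
    # "deepgemm only support bf16", hence the bf16 output here as well
    return (A @ B.T).to(torch.bfloat16)
```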
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -12,8 +12,8 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        if
+        if _ENABLE_JIT_DEEPGEMM:
             q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
sglang/srt/layers/quantization/gptq.py
CHANGED
@@ -37,6 +37,14 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
+    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
+    # compat: autogptq <=0.7.1 is_marlin_format: bool
+    return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
+        "is_marlin_format", False
+    )
+
+
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
 
@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        is_marlin_format = check_marlin_format(hf_quant_cfg)
+
         can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
 
         is_valid_user_quant = (
             user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
         )
 
-        if can_convert and is_valid_user_quant:
+        if not is_marlin_format and can_convert and is_valid_user_quant:
             msg = (
                 "The model is convertible to {} during runtime."
                 " Using {} kernel.".format(cls.get_name(), cls.get_name())
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
             logger.info(msg)
             return cls.get_name()
 
-        if can_convert and user_quant == "gptq":
+        if not is_marlin_format and can_convert and user_quant == "gptq":
             logger.info(
                 "Detected that the model can run with gptq_marlin"
                 ", however you specified quantization=gptq explicitly,"
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-
-        # compat: autogptq <=0.7.1 is_marlin_format: bool
-        is_marlin_format = hf_quant_cfg.get(
-            "checkpoint_format"
-        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+        is_marlin_format = check_marlin_format(hf_quant_cfg)
 
         is_valid_user_quant = (
             user_quant is None or user_quant == "gptq" or user_quant == "marlin"
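gptq.py now routes both Marlin checkpoint conventions through a single `check_marlin_format` helper and uses it to skip runtime conversion for checkpoints that are already in Marlin format. A minimal illustration of the helper on made-up config dicts:

```python
# The helper body is copied from the hunk above; the example dicts are invented.
from typing import Any, Dict

def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
        "is_marlin_format", False
    )

print(check_marlin_format({"checkpoint_format": "marlin"}))  # True (newer key)
print(check_marlin_format({"is_marlin_format": True}))       # True (autogptq <= 0.7.1)
print(check_marlin_format({"quant_method": "gptq"}))         # False
```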
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
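The new `sglang_per_token_group_quant_int8` is a thin wrapper that allocates the int8 output and the per-group fp32 scales, then hands the work to the `sgl_per_token_group_quant_int8` CUDA kernel. For reference, a rough pure-PyTorch sketch of the per-token-group symmetric quantization it delegates, assuming scale = group absmax / 127 with an eps floor; this is a readable reference, not the actual kernel.

```python
import torch

def per_token_group_quant_int8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0 and x.is_contiguous()
    # [..., num_groups, group_size]
    groups = x.view(*x.shape[:-1], -1, group_size).float()
    # one fp32 scale per group, floored by eps to avoid division by zero
    scales = groups.abs().amax(dim=-1).clamp_min(eps) / 127.0
    x_q = (groups / scales.unsqueeze(-1)).round().clamp(-128, 127)
    return x_q.to(torch.int8).view_as(x), scales

x_q, x_s = per_token_group_quant_int8_reference(torch.randn(4, 256), group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```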
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda
 
-if
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-
-if
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
 
sglang/srt/layers/radix_attention.py
CHANGED
@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-
-
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
-            q,
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
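`RadixAttention.forward` now accepts `**kwargs` and forwards them to the attention backend, so callers can thread extra tensors such as `k_rope` through without changing the base signature; when `k_rope` is supplied, `k` is reshaped with `v_head_dim` instead of `qk_head_dim`. A toy, self-contained sketch of the passthrough pattern (not sglang code; the class and tensor shapes are invented):

```python
import torch

class ToyBackend:
    def forward(self, q, k, v, layer, forward_batch, save_kv_cache, **kwargs):
        # The backend sees whatever extra keyword tensors the caller provided.
        extra = ", ".join(kwargs) or "none"
        return f"attn(q={tuple(q.shape)}, extra kwargs: {extra})"

class ToyAttention:
    def __init__(self, backend):
        self.backend = backend

    def forward(self, q, k, v, forward_batch=None, save_kv_cache=True, **kwargs):
        # Pass extra kwargs straight through, mirroring the diff above.
        return self.backend.forward(q, k, v, self, forward_batch, save_kv_cache, **kwargs)

attn = ToyAttention(ToyBackend())
q = k = v = torch.zeros(2, 4, 8)
print(attn.forward(q, k, v))                               # no extra kwargs
print(attn.forward(q, k, v, k_rope=torch.zeros(2, 4, 8)))  # k_rope forwarded
```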
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -8,14 +8,12 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda
 
-
+_is_cuda = is_cuda()
 
-if
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -82,8 +80,14 @@ class RotaryEmbedding(CustomOp):
 
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not
+        if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -149,7 +153,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -652,7 +656,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
+        dtype = query.dtype
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         else:
             query = query_rot
             key = key_rot
-        return query, key
+        return query.to(dtype), key.to(dtype)
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def
-
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
             )
-
-
-
-
-
-
-
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-
-
-
-
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas
 
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) ->
-        return
-
-
-
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
                 )
-
-
-
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
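The rewritten `MRotaryEmbedding.get_rope_index` (copied from the transformers Qwen2-VL implementation) now returns a `[3, batch, seq]` tensor of temporal/height/width position ids plus per-sequence `mrope_position_deltas`, instead of the old per-request Python lists. A quick sanity sketch of the text-only branch, with made-up token ids:

```python
# Text-only case: all three mrope channels carry ordinary 0..s-1 positions
# and the delta is 0. The input values here are illustrative.
import torch

input_ids = torch.arange(7).unsqueeze(0)  # [batch=1, seq_len=7], dummy tokens
s = input_ids.shape[1]
position_ids = torch.arange(s).unsqueeze(0).expand(3, -1, -1)  # [3, 1, 7]
mrope_position_deltas = position_ids.max(0)[0].max(-1, keepdim=True)[0] + 1 - s
print(position_ids.shape, mrope_position_deltas)  # torch.Size([3, 1, 7]) tensor([[0]])
```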
sglang/srt/layers/sampler.py
CHANGED
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var,
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
-if
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,