sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
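Two renames above move the single-batch benchmark: bench_latency.py is cut down to a single line while the new bench_one_batch.py is added, and bench_server_latency.py becomes bench_one_batch_server.py. A quick post-upgrade sanity check (a minimal sketch; it only assumes the wheel installed cleanly) is:

```python
import importlib

import sglang

# sglang/version.py carries the one-line version bump in this diff.
assert sglang.__version__ == "0.3.6.post1", sglang.__version__

# The one-batch benchmark now lives in a new module.
importlib.import_module("sglang.bench_one_batch")
```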
sglang/srt/constrained/xgrammar_backend.py
CHANGED

@@ -1,34 +1,30 @@
-[15 lines removed: the old docstring-style license header; its text was not captured in this view]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Constrained decoding with xgrammar backend."""
 
 import logging
 from typing import List, Tuple
 
 import torch
-[8 lines removed: the old guarded xgrammar import; not captured in this view]
-)
-import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -38,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
 logger = logging.getLogger(__name__)
 
 
-MAX_ROLLBACK_TOKENS = …
+MAX_ROLLBACK_TOKENS = 200
 
 
 class XGrammarGrammar(BaseGrammarObject):
@@ -80,20 +76,25 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def …
-[6 lines removed: the old vocab-mask method body; not captured in this view]
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return allocate_token_bitmask(batch_size, vocab_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must then be on the same device as logits
+            # when applying the token bitmask, so we check and move if needed
+            vocab_mask = vocab_mask.to(logits.device)
+
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
 
@@ -105,26 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-[4 lines removed; not captured in this view]
-        self.grammar_cache = None
-        return
-
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
 
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.…
-                    key_string
-                )
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -138,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        if self.…
-            self.…
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
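Taken together, these hunks replace the old guarded import and CachedGrammarCompiler with xgrammar's compiler/matcher/bitmask API. A minimal end-to-end sketch of that flow, assuming an xgrammar build that exports the symbols imported above (the tokenizer name and schema below are placeholders, not values from the package):

```python
import torch
from transformers import AutoTokenizer
from xgrammar import (
    GrammarCompiler,
    GrammarMatcher,
    TokenizerInfo,
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
)

# Hypothetical tokenizer; any HF tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
vocab_size = len(tokenizer)

# Compile once per schema; the compiler caches compiled grammars.
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
compiled = compiler.compile_json_schema('{"type": "object"}')
matcher = GrammarMatcher(compiled, max_rollback_tokens=200)

# Per decode step: fill the bitmask for batch row 0, mask the logits in
# place, then feed the chosen token back into the matcher.
bitmask = allocate_token_bitmask(1, vocab_size)
matcher.fill_next_token_bitmask(bitmask, 0)
logits = torch.randn(1, vocab_size)
apply_token_bitmask_inplace(logits, bitmask)
matcher.accept_token(int(logits.argmax(dim=-1)))
```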
sglang/srt/conversation.py
CHANGED
@@ -1,18 +1,16 @@
-[15 lines removed: the old docstring-style license header; its text was not captured in this view]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Conversation chat templates."""
 
 # Adapted from
sglang/srt/hf_transformers_utils.py
CHANGED

@@ -1,18 +1,16 @@
-[15 lines removed: the old docstring-style license header; its text was not captured in this view]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Huggingface Transformers."""
 
 import contextlib
sglang/srt/layers/activation.py
CHANGED
@@ -1,16 +1,16 @@
-[13 lines removed: the old license header; its text was not captured in this view]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for activation layers."""
 
 import logging
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
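The decorator comes from the new sglang/srt/layers/custom_op_util.py (+25 lines, listed above), whose body this diff view does not show. A plausible sketch of such a helper, assuming it only bridges to vLLM's CustomOp registry when the installed vLLM exposes one, is:

```python
from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name: str):
    """Return a class decorator that registers a CustomOp subclass.

    Sketch only: newer vLLM versions expose a registration hook on
    CustomOp; on older versions the decorator is a no-op, so the class
    definitions above still import cleanly either way.
    """
    if hasattr(CustomOp, "register"):
        return CustomOp.register(op_name)

    def decorator(cls):
        return cls

    return decorator
```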
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED

@@ -7,8 +7,9 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import os
 from enum import Enum, auto
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
 
         # Parse constants
-        if …
-
-
-
-            self.decode_use_tensor_cores = True
+        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
+            self.decode_use_tensor_cores = (
+                os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
+            )
         else:
-            …
+            if not _grouped_size_compiled_for_decode_kernels(
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+            ):
+                self.decode_use_tensor_cores = True
+            else:
+                self.decode_use_tensor_cores = False
+
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
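This hunk makes the tensor-core decode path explicitly controllable: SGLANG_FLASHINFER_USE_TENSOR_CORE wins when set, otherwise the kernel-availability heuristic decides. Condensed into one function (a sketch; it assumes the private helper is importable from flashinfer.decode, as in contemporary FlashInfer releases):

```python
import os

from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


def decode_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # An explicit env-var override always wins.
    if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
        return os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
    # Otherwise use tensor cores exactly when FlashInfer ships no
    # specialized CUDA-core decode kernel for this GQA group size.
    return not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)
```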
@@ -136,15 +143,17 @@ class FlashInferAttnBackend(AttentionBackend):
         prefix_lens = forward_batch.extend_prefix_lens
 
         # Some heuristics to check whether to use ragged forward
-        use_ragged = False
         if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
             use_ragged = True
-
-
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+        else:
+            use_ragged = False
+            extend_no_prefix = False
 
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
+            forward_batch.seq_lens_sum,
             prefix_lens,
             use_ragged=use_ragged,
             encoder_lens=forward_batch.encoder_lens,
@@ -314,7 +323,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -335,7 +343,12 @@ class FlashInferIndicesUpdaterDecode:
         self.update = self.update_single_wrapper
 
     def update(
-        self,
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -345,8 +358,8 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers…
-        encoder_lens…
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -363,8 +376,8 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers…
-        encoder_lens…
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -394,11 +407,11 @@ class FlashInferIndicesUpdaterDecode:
 
     def update_cross_attention(
         self,
-        req_pool_indices,
-        seq_lens,
-        seq_lens_sum,
-        decode_wrappers…
-        encoder_lens…
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers: List,
+        encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
 
@@ -425,11 +438,11 @@ class FlashInferIndicesUpdaterDecode:
     def call_begin_forward(
         self,
         wrapper,
-        req_pool_indices,
-        paged_kernel_lens,
-        paged_kernel_lens_sum,
-        kv_indptr,
-        kv_start_idx,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        kv_indptr: torch.Tensor,
+        kv_start_idx: torch.Tensor,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -445,7 +458,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.…
+            self.req_to_token.shape[1],
         )
 
         wrapper.end_forward()
@@ -474,7 +487,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
 
         self.attn_backend = attn_backend
@@ -496,23 +508,40 @@ class FlashInferIndicesUpdaterPrefill:
         assert self.attn_backend.num_wrappers == 1
         self.update = self.update_single_wrapper
 
-    def update(…
+    def update(
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
+    ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
 
     def update_single_wrapper(
-        self,…
+        self,
+        req_pool_indices: torch.Tnesor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
+            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
             paged_kernel_lens = seq_lens
+            paged_kernel_lens_sum = seq_lens_sum
 
         self.call_begin_forward(
             self.wrapper_ragged,
             self.wrappers_paged[0],
             req_pool_indices,
             paged_kernel_lens,
+            paged_kernel_lens_sum,
             seq_lens,
             prefix_lens,
             None,
@@ -522,7 +551,13 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
     def update_sliding_window(
-        self,…
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -531,9 +566,12 @@ class FlashInferIndicesUpdaterPrefill:
                     seq_lens,
                     torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
                 )
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
             else:
                 # full attention
                 paged_kernel_lens = seq_lens
+                paged_kernel_lens_sum = seq_lens_sum
+
             kv_start_idx = seq_lens - paged_kernel_lens
 
             self.call_begin_forward(
@@ -541,6 +579,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -550,23 +589,32 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
     def update_cross_attention(
-        self,…
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        prefix_lens: torch.Tensor,
+        use_ragged: bool,
+        encoder_lens: torch.Tensor,
    ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # normal attention
                 paged_kernel_lens = seq_lens
                 kv_start_idx = encoder_lens
+                paged_kernel_lens_sum = seq_lens_sum
             else:
                 # cross attention
                 paged_kernel_lens = encoder_lens
                 kv_start_idx = torch.zeros_like(encoder_lens)
+                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
 
             self.call_begin_forward(
                 self.wrapper_ragged,
                 self.wrappers_paged[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                paged_kernel_lens_sum,
                 seq_lens,
                 prefix_lens,
                 kv_start_idx,
@@ -579,19 +627,22 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         wrapper_ragged,
         wrapper_paged,
-        req_pool_indices,
-        paged_kernel_lens,
-        seq_lens,
-        prefix_lens,
-        kv_start_idx,
-        kv_indptr,
-        qo_indptr,
-        use_ragged,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        seq_lens: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        kv_start_idx: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        qo_indptr: torch.Tensor,
+        use_ragged: bool,
     ):
         bs = len(req_pool_indices)
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(…
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,
@@ -599,7 +650,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.…
+            self.req_to_token.shape[1],
         )
 
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
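The recurring new seq_lens_sum: int / paged_kernel_lens_sum: int parameters are the substance of these signature changes: the flat kv_indices buffer can now be sized from a host-side integer instead of reducing a device tensor at allocation time. A small sketch of the difference (illustrative only; it assumes a CUDA device, and the literal 17 stands in for the sum the scheduler already tracks):

```python
import torch

seq_lens = torch.tensor([5, 9, 3], device="cuda")

# Before: sizing the buffer reduces a device tensor, forcing a
# GPU -> CPU synchronization on every indices update.
total = int(seq_lens.sum().item())

# After: the caller threads the host-side sum through, so the
# allocation needs no synchronization.
seq_lens_sum = 17
kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda")
```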
@@ -638,10 +689,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-    max_context_len: tl.constexpr,
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -652,15 +704,15 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for …
-
-        tl.…
-
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
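The rewritten kernel addresses req_to_token through an explicit row stride and copies in BLOCK_SIZE-sized masked chunks. As a reading aid, here is a plain-PyTorch reference of what the grid of program instances computes overall; this is an illustrative sketch, not code from the package:

```python
from typing import Optional

import torch


def build_kv_indices_reference(
    req_to_token: torch.Tensor,       # [num_slots, max_context_len]
    req_pool_indices: torch.Tensor,   # [bs] table row per request
    paged_kernel_lens: torch.Tensor,  # [bs] tokens to copy per request
    kv_indptr: torch.Tensor,          # [bs + 1] output offsets (cumsum)
    kv_start_idx: Optional[torch.Tensor],  # per-request start, or None
) -> torch.Tensor:
    kv_indices = torch.empty(int(paged_kernel_lens.sum()), dtype=torch.int32)
    for pid in range(len(req_pool_indices)):  # one Triton program per request
        row = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[pid])
        out = int(kv_indptr[pid])
        # The kernel performs this copy in masked BLOCK_SIZE chunks.
        kv_indices[out : out + (end - start)] = req_to_token[row, start:end]
    return kv_indices
```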
sglang/srt/layers/attention/triton_backend.py
CHANGED

@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -28,9 +27,13 @@ class TritonAttnBackend(AttentionBackend):
 
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+
+        if model_runner.server_args.enable_dp_attention:
+            self.num_head = model_runner.model_config.num_attention_heads
+        else:
+            self.num_head = (
+                model_runner.model_config.num_attention_heads // model_runner.tp_size
+            )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
@@ -50,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
         start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-        total_num_tokens = …
+        total_num_tokens = forward_batch.seq_lens_sum
         attn_logits = torch.empty(
             (self.num_head, total_num_tokens),
             dtype=self.reduce_dtype,
@@ -61,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-
-            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
 
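Since attn_logits is allocated as (self.num_head, total_num_tokens), the head-count rule introduced above matters for buffer sizing: data-parallel attention keeps every head on each rank, while tensor parallelism shards them. As a one-function sketch (the names are illustrative, not from the package):

```python
def local_num_attention_heads(
    num_attention_heads: int, tp_size: int, enable_dp_attention: bool
) -> int:
    # DP attention replicates the full head set per rank; TP shards the
    # heads evenly across tp_size ranks.
    if enable_dp_attention:
        return num_attention_heads
    return num_attention_heads // tp_size
```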