sglang 0.3.6__py3-none-any.whl → 0.3.6.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +4 -7
- sglang/bench_one_batch_server.py +2 -2
- sglang/bench_serving.py +75 -26
- sglang/check_env.py +7 -1
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +0 -3
- sglang/srt/configs/model_config.py +15 -20
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +14 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -18
- sglang/srt/managers/image_processor.py +6 -9
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +92 -27
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +94 -72
- sglang/srt/managers/session_controller.py +29 -19
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +20 -16
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +60 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +208 -19
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/METADATA +25 -15
- sglang-0.3.6.post2.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -1,39 +1,30 @@
-[old docstring-form license header; its text is not captured in this rendering]
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Constrained decoding with xgrammar backend."""
 
 import logging
 from typing import List, Tuple
 
 import torch
-[old guarded xgrammar import; most of its lines are not captured in this rendering]
-import_error = None
-except ImportError as e:
-    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
-        ImportError
-    )
-    import_error = e
+from xgrammar import (
+    CompiledGrammar,
+    GrammarCompiler,
+    GrammarMatcher,
+    TokenizerInfo,
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -43,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
 logger = logging.getLogger(__name__)
 
 
-MAX_ROLLBACK_TOKENS =
+MAX_ROLLBACK_TOKENS = 200
 
 
 class XGrammarGrammar(BaseGrammarObject):
@@ -88,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject):
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
-        return
+        return allocate_token_bitmask(batch_size, vocab_size)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-
+        if vocab_mask.device.type != logits.device.type:
+            # vocab_mask must then be on the same device as logits
+            # when applying the token bitmask, so we check and move if needed
+            vocab_mask = vocab_mask.to(logits.device)
+
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
-        matcher = GrammarMatcher(
-            self.ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
 
@@ -114,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-
-
-
-
-            self.grammar_cache = None
-            return
-
-        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size
+        )
+        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        if import_error:
-            raise import_error
 
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.
+                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -146,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        matcher = GrammarMatcher(
-            ctx,
-            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            vocab_size=self.vocab_size,
-        )
+        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        if self.
-            self.
+        if self.grammar_compiler:
+            self.grammar_compiler.clear_cache()
sglang/srt/conversation.py
CHANGED
@@ -1,18 +1,16 @@
-[old docstring-form license header; its text is not captured in this rendering]
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Conversation chat templates."""
 
 # Adapted from
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -1,18 +1,16 @@
-[old docstring-form license header; its text is not captured in this rendering]
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Huggingface Transformers."""
 
 import contextlib
sglang/srt/layers/activation.py
CHANGED
@@ -1,16 +1,16 @@
-[old docstring-form license header; its text is not captured in this rendering]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for activation layers."""
 
 import logging
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import os
 from enum import Enum, auto
 from typing import TYPE_CHECKING, List
 
@@ -17,7 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
 
         # Parse constants
-        if
-
-
-
-            self.decode_use_tensor_cores = True
+        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
+            self.decode_use_tensor_cores = get_bool_env_var(
+                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
+            )
         else:
-
+            if not _grouped_size_compiled_for_decode_kernels(
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+            ):
+                self.decode_use_tensor_cores = True
+            else:
+                self.decode_use_tensor_cores = False
+
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
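The hunk above lets the SGLANG_FLASHINFER_USE_TENSOR_CORE environment variable override the tensor-core heuristic for FlashInfer decoding. get_bool_env_var comes from sglang/srt/utils.py, whose diff is not reproduced here; the snippet below is a plausible minimal reading of it, stated as an assumption rather than the package's verbatim code.

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed behavior: common truthy spellings count as True, anything else as False.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "y")

# With the change above, setting the variable bypasses the kernel heuristic entirely:
#   SGLANG_FLASHINFER_USE_TENSOR_CORE=true   -> decode_use_tensor_cores = True
#   SGLANG_FLASHINFER_USE_TENSOR_CORE=false  -> decode_use_tensor_cores = False
# Leaving it unset falls back to _grouped_size_compiled_for_decode_kernels().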
sglang/srt/layers/attention/triton_ops/decode_attention.py
CHANGED
@@ -1,18 +1,16 @@
-[old docstring-form license header; its text is not captured in this rendering]
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for decoding.
 It supports page size = 1.
@@ -26,6 +24,8 @@ import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -52,12 +52,13 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
@@ -67,22 +68,18 @@ def _fwd_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
+    q = tl.load(Q + off_q).to(reduce_dtype)
 
-
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-
-
-
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
-            mask=
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -92,7 +89,7 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=(
+            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
@@ -102,7 +99,7 @@ def _fwd_kernel_stage1(
             att_value = logit_cap * tanh(att_value / logit_cap)
 
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=
+        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)
 
 
 @triton.jit
@@ -191,11 +188,12 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
+    SPLIT_K = 8
     Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
-    grid = (batch, head_num,
+    grid = (batch, head_num, SPLIT_K)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -223,6 +221,7 @@ def _decode_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -294,13 +293,14 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
 
@@ -317,30 +317,27 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+    q = tl.load(
+        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
+    ).to(reduce_dtype)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
+        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
 
-
-
-
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-    for
-
-            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
-        ).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
-            mask=
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
            other=0,
        )
        offs_buf_k = (
@@ -350,14 +347,11 @@ def _fwd_grouped_kernel_stage1(
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
-            mask=(
+            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
            other=0.0,
        ).to(reduce_dtype)
        qk = tl.dot(q, k)
        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                reduce_dtype
-            )
            offs_buf_kpe = (
                k_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
@@ -365,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
            )
            kpe = tl.load(
                K_Buffer + offs_buf_kpe,
-                mask=
+                mask=offs_n[None, :] < split_k_end,
                other=0.0,
            ).to(reduce_dtype)
            qk += tl.dot(qpe, kpe)
@@ -381,7 +375,7 @@ def _fwd_grouped_kernel_stage1(
        tl.store(
            Att_Out + offs_o,
            qk,
-            mask=mask_h[:, None] & (
+            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
        )
 
 
@@ -499,16 +493,17 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
+    SPLIT_K = 8
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-
+        SPLIT_K,
     )
 
     num_warps = 4
 
     extra_kargs = {}
-    if
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -534,6 +529,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -563,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
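The substantive change in these kernels is a split-K decode: the stage-1 kernels now take a third grid dimension (SPLIT_K = 8), and each program processes one contiguous slice of the request's KV cache, masked by split_k_end instead of the full sequence length. The slice arithmetic, restated as plain Python for illustration (not kernel code):

import math

SPLIT_K = 8  # matches the constant set in _decode_att_m_fwd / _decode_grouped_att_m_fwd

def split_k_bounds(cur_batch_seq_len: int, split_k_id: int) -> tuple:
    # Mirrors the kernel: tl.cdiv, then clamp the last slice to the sequence length.
    kv_len_per_split = math.ceil(cur_batch_seq_len / SPLIT_K)
    split_k_start = kv_len_per_split * split_k_id
    split_k_end = min(split_k_start + kv_len_per_split, cur_batch_seq_len)
    return split_k_start, split_k_end

# A 100-token KV cache is cut into 13-token slices, with the last one clipped:
# [(0, 13), (13, 26), (26, 39), (39, 52), (52, 65), (65, 78), (78, 91), (91, 100)]
print([split_k_bounds(100, i) for i in range(SPLIT_K)])

Slices that would start past the end of the sequence get an empty range and store nothing, which is why the load and store masks above compare against split_k_end rather than the full sequence length.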
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -1,18 +1,16 @@
-[old docstring-form license header; its text is not captured in this rendering]
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for prefill.
 It supports page size = 1 and prefill with KV cache (i.e. extend).
@@ -31,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -313,7 +313,7 @@ def extend_attention_fwd(
     num_stages = 1
 
     extra_kargs = {}
-    if
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""
|
17
15
|
Memory-efficient attention for prefill.
|
18
16
|
It supporst page size = 1.
|