sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +2 -4
- sglang/bench_serving.py +75 -26
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +13 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -16
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +55 -26
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +89 -70
- sglang/srt/managers/session_controller.py +14 -15
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +13 -15
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +59 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +196 -17
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
CHANGED
```diff
@@ -1,18 +1,16 @@
-…
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Huggingface Transformers."""
 
 import contextlib
```
sglang/srt/layers/activation.py
CHANGED
```diff
@@ -1,16 +1,16 @@
-…
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for activation layers."""
 
 import logging
```
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
```diff
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import os
 from enum import Enum, auto
 from typing import TYPE_CHECKING, List
 
@@ -45,13 +46,19 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
 
         # Parse constants
-        if …
-            self.decode_use_tensor_cores = True
+        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
+            self.decode_use_tensor_cores = (
+                os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
+            )
         else:
-            …
+            if not _grouped_size_compiled_for_decode_kernels(
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+            ):
+                self.decode_use_tensor_cores = True
+            else:
+                self.decode_use_tensor_cores = False
 
         self.max_context_len = model_runner.model_config.context_len
 
         assert not (
```
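The new selection order: an explicit `SGLANG_FLASHINFER_USE_TENSOR_CORE` environment variable wins; otherwise the FlashInfer heuristic decides. A minimal plain-Python sketch of that logic, with `_grouped_size_compiled_for_decode_kernels` stubbed here — the real helper ships with FlashInfer, and the group sizes in the stub are an assumption, not FlashInfer's actual table:

```python
import os

def _grouped_size_compiled_for_decode_kernels(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Stub standing in for FlashInfer's helper: True when a specialized
    # decode kernel was compiled for this GQA group size (values assumed).
    return num_qo_heads // num_kv_heads in (1, 2, 4, 8)

def decode_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # The environment variable, when set, overrides the heuristic entirely.
    if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
        return os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
    # Otherwise: use tensor cores only when no specialized grouped kernel exists.
    return not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads)
```

With the variable unset, behavior matches the old heuristic; setting it to `true` or `false` forces the choice either way.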
sglang/srt/layers/attention/triton_ops/decode_attention.py
CHANGED
```diff
@@ -1,18 +1,16 @@
-…
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for decoding.
 It supports page size = 1.
@@ -26,6 +24,8 @@ import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -52,12 +52,13 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-    …
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
@@ -67,22 +68,18 @@ def _fwd_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
+    q = tl.load(Q + off_q).to(reduce_dtype)
 
-    …
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-    …
-    for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + …,
-            mask=…,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -92,7 +89,7 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=…,
+            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
@@ -102,7 +99,7 @@ def _fwd_kernel_stage1(
             att_value = logit_cap * tanh(att_value / logit_cap)
 
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=…)
+        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)
 
 
 @triton.jit
@@ -191,11 +188,12 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
+    SPLIT_K = 8
    Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
-    grid = (batch, head_num, …
+    grid = (batch, head_num, SPLIT_K)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -223,6 +221,7 @@ def _decode_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -294,13 +293,14 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-    …
+    split_k_id = tl.program_id(2)
 
     reduce_dtype = Att_Out.dtype.element_ty
 
@@ -317,30 +317,27 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
-    cur_batch_start_index = 0
-    cur_batch_end_index = cur_batch_seq_len
-
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+    q = tl.load(
+        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
+    ).to(reduce_dtype)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
+        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
 
-    …
-    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
+    split_k_start = kv_len_per_split * split_k_id
+    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
 
-    for …
-        …
-            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
-        ).to(reduce_dtype)
-        offs_n_new = cur_batch_start_index + offs_n
+    for start_n in range(split_k_start, split_k_end, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
         k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + …,
-            mask=…,
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            mask=offs_n < split_k_end,
             other=0,
         )
         offs_buf_k = (
@@ -350,14 +347,11 @@ def _fwd_grouped_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=…,
+            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
             other=0.0,
         ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                reduce_dtype
-            )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
@@ -365,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
             )
             kpe = tl.load(
                 K_Buffer + offs_buf_kpe,
-                mask=…,
+                mask=offs_n[None, :] < split_k_end,
                 other=0.0,
             ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
@@ -381,7 +375,7 @@ def _fwd_grouped_kernel_stage1(
         tl.store(
             Att_Out + offs_o,
             qk,
-            mask=mask_h[:, None] & (…),
+            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
         )
 
 
@@ -499,16 +493,17 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
+    SPLIT_K = 8
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        …
+        SPLIT_K,
     )
 
     num_warps = 4
 
     extra_kargs = {}
-    if …
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -534,6 +529,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        SPLIT_K=SPLIT_K,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
@@ -563,7 +559,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     extra_kargs = {}
-    if …
+    if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
```
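The recurring change in the kernels above: the launch grid gains a third axis (`SPLIT_K = 8`), and each program now handles one contiguous chunk of the cached KV sequence instead of all of it. A plain-Python sketch (illustrative only) of the chunking arithmetic:

```python
import math

def split_k_ranges(seq_len: int, split_k: int = 8):
    """Yield the [start, end) KV range each split-K program covers."""
    kv_len_per_split = math.ceil(seq_len / split_k)  # mirrors tl.cdiv
    for split_k_id in range(split_k):
        start = kv_len_per_split * split_k_id
        end = min(start + kv_len_per_split, seq_len)
        if start < end:  # programs past the sequence end do no work
            yield start, end

# For a 100-token KV cache and SPLIT_K=8: chunks of 13 tokens, last chunk 9.
print(list(split_k_ranges(100)))
```

The per-chunk attention logits are masked against `split_k_end` on store, and the softmax/reduce-V stage (`_decode_grouped_softmax_reducev_fwd` above) combines them.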
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
```diff
@@ -1,18 +1,16 @@
-…
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for prefill.
 It supports page size = 1 and prefill with KV cache (i.e. extend).
@@ -31,6 +29,8 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
+is_hip_ = is_hip()
+
 
 @triton.jit
 def tanh(x):
@@ -313,7 +313,7 @@ def extend_attention_fwd(
     num_stages = 1
 
     extra_kargs = {}
-    if …
+    if is_hip_:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
```
sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED
```diff
@@ -1,18 +1,16 @@
-…
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Memory-efficient attention for prefill.
 It supporst page size = 1.
```
sglang/srt/layers/custom_op_util.py
CHANGED
```diff
@@ -1,17 +1,16 @@
-…
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from vllm.model_executor.custom_op import CustomOp
 
```
sglang/srt/layers/fused_moe_grok/__init__.py
ADDED
```diff
@@ -0,0 +1 @@
+from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase
```
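For downstream imports, the practical effect of the `fused_moe` → `fused_moe_grok` / `fused_moe_triton` split is the module path. A sketch of before/after; the old path is an assumption inferred from the removed `fused_moe/__init__.py`, and the `fused_moe_triton` exports appear later in this diff:

```python
# Before (0.3.6) -- assumed old path, removed in this release:
# from sglang.srt.layers.fused_moe import FusedMoE, FusedMoEMethodBase

# After (0.3.6.post1) -- each backend gets its own namespace:
from sglang.srt.layers.fused_moe_grok import FusedMoE, FusedMoEMethodBase  # Grok path
from sglang.srt.layers.fused_moe_triton import fused_experts, fused_topk   # new Triton path
```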
sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py
CHANGED
```diff
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
-from sglang.srt.layers.…
+from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 
 logger = init_logger(__name__)
@@ -123,7 +123,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int],
         topk_group: Optional[int],
     ) -> torch.Tensor:
-        from sglang.srt.layers.…
+        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
 
         return fused_moe(
             x,
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int],
         topk_group: Optional[int],
     ) -> torch.Tensor:
-        …
-        assert not use_grouped_topk
-        assert num_expert_group is None
-        assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
 
 
 class FusedMoE(torch.nn.Module):
@@ -614,7 +609,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_group: Optional[int] = None,
     ) -> torch.Tensor:
 
-        from sglang.srt.layers.…
+        from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
 
         return fused_moe(
             x,
```
sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py}
CHANGED
```diff
@@ -1,3 +1,8 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
 from typing import Callable, Optional
 
 import torch
```
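The cited gpt-fast code implements MoE with plain tensor indexing and einsums so `torch.compile` can trace it without custom kernels. A condensed, illustrative sketch of that pattern; the names and shapes below are ours for illustration, not sglang's API:

```python
import torch
import torch.nn.functional as F

def torch_native_moe(
    x: torch.Tensor,              # [T, D] tokens
    w1: torch.Tensor,             # [E, I, D] gate projection per expert
    w3: torch.Tensor,             # [E, I, D] up projection per expert
    w2: torch.Tensor,             # [E, D, I] down projection per expert
    router_logits: torch.Tensor,  # [T, E]
    top_k: int,
    renormalize: bool = True,
) -> torch.Tensor:
    weights = F.softmax(router_logits, dim=-1)
    weights, expert_indices = torch.topk(weights, top_k)  # both [T, top_k]
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    # Gather each token's top-k expert weight matrices, then apply SwiGLU.
    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1[expert_indices]))
    x3 = torch.einsum("ti,taoi->tao", x, w3[expert_indices])
    out = torch.einsum("tao,taio->tai", x1 * x3, w2[expert_indices])
    # Mix the top-k expert outputs with the routing weights.
    return torch.einsum("tai,ta->ti", out, weights)
```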
sglang/srt/layers/fused_moe_triton/__init__.py
ADDED
```diff
@@ -0,0 +1,44 @@
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
+
+import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
+    fused_experts,
+    fused_topk,
+    get_config_file_name,
+    grouped_topk,
+)
+from sglang.srt.layers.fused_moe_triton.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+
+_config: Optional[Dict[str, Any]] = None
+
+
+@contextmanager
+def override_config(config):
+    global _config
+    old_config = _config
+    _config = config
+    yield
+    _config = old_config
+
+
+def get_config() -> Optional[Dict[str, Any]]:
+    return _config
+
+
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+    "FusedMoeWeightScaleSupported",
+    "override_config",
+    "get_config",
+    "fused_moe",
+    "fused_topk",
+    "fused_experts",
+    "get_config_file_name",
+    "grouped_topk",
+]
```
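A hypothetical use of the new `override_config`/`get_config` pair: temporarily pin a fused-MoE kernel config for a region of code instead of the autotuned default. The `BLOCK_SIZE_*` keys below are illustrative config entries, not a schema this module checks:

```python
from sglang.srt.layers.fused_moe_triton import get_config, override_config

custom = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}

assert get_config() is None           # no override installed by default
with override_config(custom):
    assert get_config() == custom     # fused-MoE kernels can read the pinned config
assert get_config() is None           # previous value restored on exit
```

The context manager simply swaps a module-level global in and out, so overrides nest naturally and always restore the prior value.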