sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
|
|
8
8
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
9
9
|
|
10
10
|
|
11
|
+
@triton.jit
|
12
|
+
def tanh(x):
|
13
|
+
# Tanh is just a scaled sigmoid
|
14
|
+
return 2 * tl.sigmoid(2 * x) - 1
|
15
|
+
|
16
|
+
|
11
17
|
@triton.jit
|
12
18
|
def _fwd_kernel(
|
13
19
|
Q_Extend,
|
@@ -39,6 +45,7 @@ def _fwd_kernel(
|
|
39
45
|
BLOCK_DMODEL: tl.constexpr,
|
40
46
|
BLOCK_M: tl.constexpr,
|
41
47
|
BLOCK_N: tl.constexpr,
|
48
|
+
logit_cap: tl.constexpr,
|
42
49
|
):
|
43
50
|
cur_seq = tl.program_id(0)
|
44
51
|
cur_head = tl.program_id(1)
|
@@ -90,6 +97,10 @@ def _fwd_kernel(
|
|
90
97
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
91
98
|
qk += tl.dot(q, k)
|
92
99
|
qk *= sm_scale
|
100
|
+
|
101
|
+
if logit_cap > 0:
|
102
|
+
qk = logit_cap * tanh(qk / logit_cap)
|
103
|
+
|
93
104
|
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
94
105
|
|
95
106
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
@@ -126,6 +137,10 @@ def _fwd_kernel(
|
|
126
137
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
127
138
|
qk += tl.dot(q, k)
|
128
139
|
qk *= sm_scale
|
140
|
+
|
141
|
+
if logit_cap > 0:
|
142
|
+
qk = logit_cap * tanh(qk / logit_cap)
|
143
|
+
|
129
144
|
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
130
145
|
start_n + offs_n[None, :]
|
131
146
|
)
|
@@ -176,6 +191,7 @@ def extend_attention_fwd(
|
|
176
191
|
b_seq_len_extend,
|
177
192
|
max_len_in_batch,
|
178
193
|
max_len_extend,
|
194
|
+
logit_cap=-1,
|
179
195
|
):
|
180
196
|
"""
|
181
197
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
@@ -271,6 +287,7 @@ def extend_attention_fwd(
|
|
271
287
|
BLOCK_N=BLOCK_N,
|
272
288
|
num_warps=num_warps,
|
273
289
|
num_stages=num_stages,
|
290
|
+
logit_cap=logit_cap,
|
274
291
|
)
|
275
292
|
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
276
293
|
|
@@ -0,0 +1,485 @@
|
|
1
|
+
# Adapted from
|
2
|
+
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
|
3
|
+
"""Fused MoE kernel."""
|
4
|
+
import functools
|
5
|
+
import json
|
6
|
+
import os
|
7
|
+
from typing import Any, Dict, Optional, Tuple
|
8
|
+
|
9
|
+
import torch
|
10
|
+
import triton
|
11
|
+
import triton.language as tl
|
12
|
+
|
13
|
+
from vllm import _custom_ops as ops
|
14
|
+
from vllm.logger import init_logger
|
15
|
+
from vllm.utils import is_hip
|
16
|
+
|
17
|
+
logger = init_logger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
@triton.jit
|
21
|
+
def fused_moe_kernel(
|
22
|
+
# Pointers to matrices
|
23
|
+
a_ptr,
|
24
|
+
b_ptr,
|
25
|
+
c_ptr,
|
26
|
+
a_scale_ptr,
|
27
|
+
b_scale_ptr,
|
28
|
+
topk_weights_ptr,
|
29
|
+
sorted_token_ids_ptr,
|
30
|
+
expert_ids_ptr,
|
31
|
+
num_tokens_post_padded_ptr,
|
32
|
+
# Matrix dimensions
|
33
|
+
N,
|
34
|
+
K,
|
35
|
+
EM,
|
36
|
+
num_valid_tokens,
|
37
|
+
# The stride variables represent how much to increase the ptr by when
|
38
|
+
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
39
|
+
# how much to increase `a_ptr` by to get the element one row down
|
40
|
+
# (A has M rows).
|
41
|
+
stride_am,
|
42
|
+
stride_ak,
|
43
|
+
stride_be,
|
44
|
+
stride_bk,
|
45
|
+
stride_bn,
|
46
|
+
stride_cm,
|
47
|
+
stride_cn,
|
48
|
+
# Meta-parameters
|
49
|
+
BLOCK_SIZE_M: tl.constexpr,
|
50
|
+
BLOCK_SIZE_N: tl.constexpr,
|
51
|
+
BLOCK_SIZE_K: tl.constexpr,
|
52
|
+
GROUP_SIZE_M: tl.constexpr,
|
53
|
+
MUL_ROUTED_WEIGHT: tl.constexpr,
|
54
|
+
top_k: tl.constexpr,
|
55
|
+
compute_type: tl.constexpr,
|
56
|
+
use_fp8: tl.constexpr,
|
57
|
+
):
|
58
|
+
"""
|
59
|
+
Implements the fused computation for a Mixture of Experts (MOE) using
|
60
|
+
token and expert matrices.
|
61
|
+
|
62
|
+
Key Parameters:
|
63
|
+
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
64
|
+
be any shape representing batches and K is the feature dimension of
|
65
|
+
each token.
|
66
|
+
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
67
|
+
the number of experts, K is the input feature dimension, and N is
|
68
|
+
the output feature dimension.
|
69
|
+
- C: The output cache tensor with shape (M, topk, N), where M is the
|
70
|
+
total number of tokens post padding, topk is the number of times
|
71
|
+
each token is repeated, and N is the output feature dimension.
|
72
|
+
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
73
|
+
repeated topk times and arranged by the expert index they are
|
74
|
+
assigned to.
|
75
|
+
- expert_ids: A tensor containing the indices of the expert for each
|
76
|
+
block. It determines which expert matrix from B should be used for
|
77
|
+
each block in A.
|
78
|
+
This kernel performs the multiplication of a token by its corresponding
|
79
|
+
expert matrix as determined by `expert_ids`. The sorting of
|
80
|
+
`sorted_token_ids` by expert index and padding ensures divisibility by
|
81
|
+
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
82
|
+
multiplication across different blocks processed by the same expert.
|
83
|
+
"""
|
84
|
+
# -----------------------------------------------------------
|
85
|
+
# Map program ids `pid` to the block of C it should compute.
|
86
|
+
# This is done in a grouped ordering to promote L2 data reuse.
|
87
|
+
pid = tl.program_id(axis=0)
|
88
|
+
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
89
|
+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
90
|
+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
91
|
+
group_id = pid // num_pid_in_group
|
92
|
+
first_pid_m = group_id * GROUP_SIZE_M
|
93
|
+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
94
|
+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
95
|
+
pid_n = (pid % num_pid_in_group) // group_size_m
|
96
|
+
|
97
|
+
# ----------------------------------------------------------
|
98
|
+
# Create pointers for the first blocks of A and B.
|
99
|
+
# We will advance this pointer as we move in the K direction
|
100
|
+
# and accumulate
|
101
|
+
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
102
|
+
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
103
|
+
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
104
|
+
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
105
|
+
return
|
106
|
+
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
107
|
+
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
108
|
+
token_mask = offs_token < num_valid_tokens
|
109
|
+
|
110
|
+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
111
|
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
112
|
+
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
113
|
+
offs_k[None, :] * stride_ak)
|
114
|
+
|
115
|
+
off_experts = tl.load(expert_ids_ptr + pid_m)
|
116
|
+
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
117
|
+
offs_bn[None, :] * stride_bn)
|
118
|
+
|
119
|
+
if use_fp8:
|
120
|
+
a_scale = tl.load(a_scale_ptr)
|
121
|
+
b_scale = tl.load(b_scale_ptr + off_experts)
|
122
|
+
|
123
|
+
# -----------------------------------------------------------
|
124
|
+
# Iterate to compute a block of the C matrix.
|
125
|
+
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
126
|
+
# of fp32 values for higher accuracy.
|
127
|
+
# `accumulator` will be converted back to fp16 after the loop.
|
128
|
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
129
|
+
|
130
|
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
131
|
+
# Load the next block of A and B, generate a mask by checking the
|
132
|
+
# K dimension.
|
133
|
+
a = tl.load(a_ptrs,
|
134
|
+
mask=token_mask[:, None] &
|
135
|
+
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
136
|
+
other=0.0)
|
137
|
+
b = tl.load(b_ptrs,
|
138
|
+
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
139
|
+
other=0.0)
|
140
|
+
# We accumulate along the K dimension.
|
141
|
+
if use_fp8:
|
142
|
+
accumulator = tl.dot(a, b, acc=accumulator)
|
143
|
+
else:
|
144
|
+
accumulator += tl.dot(a, b)
|
145
|
+
# Advance the ptrs to the next K block.
|
146
|
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
147
|
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
148
|
+
|
149
|
+
if MUL_ROUTED_WEIGHT:
|
150
|
+
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
151
|
+
mask=token_mask,
|
152
|
+
other=0)
|
153
|
+
accumulator = accumulator * moe_weight[:, None]
|
154
|
+
|
155
|
+
if use_fp8:
|
156
|
+
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
157
|
+
else:
|
158
|
+
accumulator = accumulator.to(compute_type)
|
159
|
+
# -----------------------------------------------------------
|
160
|
+
# Write back the block of the output
|
161
|
+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
162
|
+
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
163
|
+
None, :]
|
164
|
+
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
165
|
+
tl.store(c_ptrs, accumulator, mask=c_mask)
|
166
|
+
|
167
|
+
|
168
|
+
def moe_align_block_size(
|
169
|
+
topk_ids: torch.Tensor, block_size: int,
|
170
|
+
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
171
|
+
"""
|
172
|
+
Aligns the token distribution across experts to be compatible with block
|
173
|
+
size for matrix multiplication.
|
174
|
+
|
175
|
+
Parameters:
|
176
|
+
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
177
|
+
top-k expert indices for each token.
|
178
|
+
- block_size: The block size used in block matrix multiplication.
|
179
|
+
- num_experts: The total number of experts.
|
180
|
+
|
181
|
+
Returns:
|
182
|
+
- sorted_token_ids: A tensor containing the sorted token indices according
|
183
|
+
to their allocated expert.
|
184
|
+
- expert_ids: A tensor indicating the assigned expert index for each block.
|
185
|
+
- num_tokens_post_padded: The total number of tokens after padding,
|
186
|
+
ensuring divisibility by block_size.
|
187
|
+
|
188
|
+
This function pads the number of tokens that each expert needs to process
|
189
|
+
so that it is divisible by block_size.
|
190
|
+
Padding ensures that during block matrix multiplication, the dimensions
|
191
|
+
align correctly.
|
192
|
+
|
193
|
+
Example:
|
194
|
+
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
|
195
|
+
block_size = 4, and num_experts = 4:
|
196
|
+
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
|
197
|
+
with each expert needing to process 3 tokens.
|
198
|
+
- As block_size is 4, we pad 1 token for each expert.
|
199
|
+
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
|
200
|
+
- Then append padding tokens [12, 12, 12, 12] for each block.
|
201
|
+
- After sorting by expert index, we obtain token_ids
|
202
|
+
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
|
203
|
+
Tokens 12 are non-existent (padding) and are ignored in
|
204
|
+
the subsequent matrix multiplication.
|
205
|
+
- The padding ensures that the total number of tokens is now divisible
|
206
|
+
by block_size for proper block matrix operations.
|
207
|
+
"""
|
208
|
+
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
209
|
+
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
210
|
+
dtype=torch.int32,
|
211
|
+
device=topk_ids.device)
|
212
|
+
sorted_ids.fill_(topk_ids.numel())
|
213
|
+
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
214
|
+
expert_ids = torch.empty((max_num_m_blocks, ),
|
215
|
+
dtype=torch.int32,
|
216
|
+
device=topk_ids.device)
|
217
|
+
num_tokens_post_pad = torch.empty((1),
|
218
|
+
dtype=torch.int32,
|
219
|
+
device=topk_ids.device)
|
220
|
+
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
|
221
|
+
expert_ids, num_tokens_post_pad)
|
222
|
+
return sorted_ids, expert_ids, num_tokens_post_pad
|
223
|
+
|
224
|
+
|
225
|
+
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
226
|
+
A_scale: Optional[torch.Tensor],
|
227
|
+
B_scale: Optional[torch.Tensor],
|
228
|
+
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
229
|
+
sorted_token_ids: torch.Tensor,
|
230
|
+
expert_ids: torch.Tensor,
|
231
|
+
num_tokens_post_padded: torch.Tensor,
|
232
|
+
mul_routed_weight: bool, top_k: int,
|
233
|
+
config: Dict[str, Any], compute_type: tl.dtype,
|
234
|
+
use_fp8: bool) -> None:
|
235
|
+
assert topk_weights.stride(1) == 1
|
236
|
+
assert sorted_token_ids.stride(0) == 1
|
237
|
+
|
238
|
+
if not use_fp8:
|
239
|
+
assert A_scale is None
|
240
|
+
assert B_scale is None
|
241
|
+
else:
|
242
|
+
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
243
|
+
assert B_scale is not None
|
244
|
+
|
245
|
+
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
246
|
+
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
247
|
+
|
248
|
+
fused_moe_kernel[grid](
|
249
|
+
A,
|
250
|
+
B,
|
251
|
+
C,
|
252
|
+
A_scale,
|
253
|
+
B_scale,
|
254
|
+
topk_weights,
|
255
|
+
sorted_token_ids,
|
256
|
+
expert_ids,
|
257
|
+
num_tokens_post_padded,
|
258
|
+
B.shape[1],
|
259
|
+
B.shape[2],
|
260
|
+
sorted_token_ids.shape[0],
|
261
|
+
topk_ids.numel(),
|
262
|
+
A.stride(0),
|
263
|
+
A.stride(1),
|
264
|
+
B.stride(0),
|
265
|
+
B.stride(2),
|
266
|
+
B.stride(1),
|
267
|
+
C.stride(1),
|
268
|
+
C.stride(2),
|
269
|
+
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
270
|
+
top_k=top_k,
|
271
|
+
compute_type=compute_type,
|
272
|
+
use_fp8=use_fp8,
|
273
|
+
**config,
|
274
|
+
)
|
275
|
+
|
276
|
+
|
277
|
+
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
278
|
+
device_name = torch.cuda.get_device_name().replace(" ", "_")
|
279
|
+
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
280
|
+
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
|
281
|
+
|
282
|
+
|
283
|
+
@functools.lru_cache
|
284
|
+
def get_moe_configs(E: int, N: int,
|
285
|
+
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
286
|
+
"""
|
287
|
+
Return optimized configurations for the fused MoE kernel.
|
288
|
+
|
289
|
+
The return value will be a dictionary that maps an irregular grid of
|
290
|
+
batch sizes to configurations of the fused_moe kernel. To evaluate the
|
291
|
+
kernel on a given batch size bs, the closest batch size in the grid should
|
292
|
+
be picked and the associated configuration chosen to invoke the kernel.
|
293
|
+
"""
|
294
|
+
|
295
|
+
# First look up if an optimized configuration is available in the configs
|
296
|
+
# directory
|
297
|
+
json_file_name = get_config_file_name(E, N, dtype)
|
298
|
+
|
299
|
+
config_file_path = os.path.join(
|
300
|
+
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
301
|
+
if os.path.exists(config_file_path):
|
302
|
+
with open(config_file_path) as f:
|
303
|
+
logger.info("Using configuration from %s for MoE layer.",
|
304
|
+
config_file_path)
|
305
|
+
# If a configuration has been found, return it
|
306
|
+
return {int(key): val for key, val in json.load(f).items()}
|
307
|
+
|
308
|
+
# If no optimized configuration is available, we will use the default
|
309
|
+
# configuration
|
310
|
+
return None
|
311
|
+
|
312
|
+
|
313
|
+
def fused_moe(
|
314
|
+
hidden_states: torch.Tensor,
|
315
|
+
w1: torch.Tensor,
|
316
|
+
w2: torch.Tensor,
|
317
|
+
gating_output: torch.Tensor,
|
318
|
+
topk: int,
|
319
|
+
renormalize: bool,
|
320
|
+
inplace: bool = False,
|
321
|
+
override_config: Optional[Dict[str, Any]] = None,
|
322
|
+
use_fp8: bool = False,
|
323
|
+
w1_scale: Optional[torch.Tensor] = None,
|
324
|
+
w2_scale: Optional[torch.Tensor] = None,
|
325
|
+
a1_scale: Optional[torch.Tensor] = None,
|
326
|
+
a2_scale: Optional[torch.Tensor] = None,
|
327
|
+
) -> torch.Tensor:
|
328
|
+
"""
|
329
|
+
This function computes a Mixture of Experts (MoE) layer using two sets of
|
330
|
+
weights, w1 and w2, and top-k gating mechanism.
|
331
|
+
|
332
|
+
Parameters:
|
333
|
+
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
334
|
+
- w1 (torch.Tensor): The first set of expert weights.
|
335
|
+
- w2 (torch.Tensor): The second set of expert weights.
|
336
|
+
- gating_output (torch.Tensor): The output of the gating operation
|
337
|
+
(before softmax).
|
338
|
+
- topk (int): The number of top-k experts to select.
|
339
|
+
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
340
|
+
- inplace (bool): If True, perform the operation in-place.
|
341
|
+
Defaults to False.
|
342
|
+
- override_config (Optional[Dict[str, Any]]): Optional override
|
343
|
+
for the kernel configuration.
|
344
|
+
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
|
345
|
+
products for w1 and w2. Defaults to False.
|
346
|
+
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
347
|
+
w1.
|
348
|
+
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
349
|
+
w2.
|
350
|
+
|
351
|
+
Returns:
|
352
|
+
- torch.Tensor: The output tensor after applying the MoE layer.
|
353
|
+
"""
|
354
|
+
# Check constraints.
|
355
|
+
assert hidden_states.shape[0] == gating_output.shape[0], (
|
356
|
+
"Number of tokens mismatch")
|
357
|
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
358
|
+
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
359
|
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
360
|
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
361
|
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
362
|
+
assert hidden_states.dtype in [
|
363
|
+
torch.float32, torch.float16, torch.bfloat16
|
364
|
+
]
|
365
|
+
M, _ = hidden_states.shape
|
366
|
+
E, N, _ = w1.shape
|
367
|
+
|
368
|
+
if is_hip():
|
369
|
+
# The MoE kernels are not yet supported on ROCm.
|
370
|
+
routing_weights = torch.softmax(gating_output,
|
371
|
+
dim=-1,
|
372
|
+
dtype=torch.float32)
|
373
|
+
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
|
374
|
+
else:
|
375
|
+
import vllm._moe_C as moe_kernels
|
376
|
+
|
377
|
+
topk_weights = torch.empty(M,
|
378
|
+
topk,
|
379
|
+
dtype=torch.float32,
|
380
|
+
device=hidden_states.device)
|
381
|
+
topk_ids = torch.empty(M,
|
382
|
+
topk,
|
383
|
+
dtype=torch.int32,
|
384
|
+
device=hidden_states.device)
|
385
|
+
token_expert_indicies = torch.empty(M,
|
386
|
+
topk,
|
387
|
+
dtype=torch.int32,
|
388
|
+
device=hidden_states.device)
|
389
|
+
moe_kernels.topk_softmax(
|
390
|
+
topk_weights,
|
391
|
+
topk_ids,
|
392
|
+
token_expert_indicies,
|
393
|
+
gating_output.float(), # TODO(woosuk): Optimize this.
|
394
|
+
)
|
395
|
+
del token_expert_indicies # Not used. Will be used in the future.
|
396
|
+
if renormalize:
|
397
|
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
398
|
+
|
399
|
+
if override_config:
|
400
|
+
config = override_config
|
401
|
+
else:
|
402
|
+
# First try to load optimal config from the file
|
403
|
+
configs = get_moe_configs(E, w2.shape[2],
|
404
|
+
"float8" if use_fp8 else None)
|
405
|
+
|
406
|
+
if configs:
|
407
|
+
# If an optimal configuration map has been found, look up the
|
408
|
+
# optimal config
|
409
|
+
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
410
|
+
else:
|
411
|
+
# Else use the default config
|
412
|
+
config = {
|
413
|
+
"BLOCK_SIZE_M": 128,
|
414
|
+
"BLOCK_SIZE_N": 64,
|
415
|
+
"BLOCK_SIZE_K": 128,
|
416
|
+
"GROUP_SIZE_M": 1,
|
417
|
+
"num_warps": 4,
|
418
|
+
"num_stages": 4
|
419
|
+
}
|
420
|
+
|
421
|
+
if M <= E:
|
422
|
+
config = {
|
423
|
+
"BLOCK_SIZE_M": 128,
|
424
|
+
"BLOCK_SIZE_N": 256,
|
425
|
+
"BLOCK_SIZE_K": 128,
|
426
|
+
"GROUP_SIZE_M": 16,
|
427
|
+
"num_warps": 8,
|
428
|
+
"num_stages": 4
|
429
|
+
}
|
430
|
+
|
431
|
+
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
432
|
+
device=hidden_states.device,
|
433
|
+
dtype=hidden_states.dtype)
|
434
|
+
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
|
435
|
+
device=hidden_states.device,
|
436
|
+
dtype=hidden_states.dtype)
|
437
|
+
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
|
438
|
+
device=hidden_states.device,
|
439
|
+
dtype=hidden_states.dtype)
|
440
|
+
|
441
|
+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
442
|
+
topk_ids, config['BLOCK_SIZE_M'], E)
|
443
|
+
compute_type = (tl.bfloat16
|
444
|
+
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
445
|
+
|
446
|
+
invoke_fused_moe_kernel(hidden_states,
|
447
|
+
w1,
|
448
|
+
intermediate_cache1,
|
449
|
+
a1_scale,
|
450
|
+
w1_scale,
|
451
|
+
topk_weights,
|
452
|
+
topk_ids,
|
453
|
+
sorted_token_ids,
|
454
|
+
expert_ids,
|
455
|
+
num_tokens_post_padded,
|
456
|
+
False,
|
457
|
+
topk_ids.shape[1],
|
458
|
+
config,
|
459
|
+
compute_type=compute_type,
|
460
|
+
use_fp8=use_fp8)
|
461
|
+
|
462
|
+
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
463
|
+
|
464
|
+
invoke_fused_moe_kernel(intermediate_cache2,
|
465
|
+
w2,
|
466
|
+
intermediate_cache3,
|
467
|
+
a2_scale,
|
468
|
+
w2_scale,
|
469
|
+
topk_weights,
|
470
|
+
topk_ids,
|
471
|
+
sorted_token_ids,
|
472
|
+
expert_ids,
|
473
|
+
num_tokens_post_padded,
|
474
|
+
True,
|
475
|
+
1,
|
476
|
+
config,
|
477
|
+
compute_type=compute_type,
|
478
|
+
use_fp8=use_fp8)
|
479
|
+
|
480
|
+
if inplace:
|
481
|
+
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
482
|
+
dim=1,
|
483
|
+
out=hidden_states)
|
484
|
+
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
485
|
+
dim=1)
|
@@ -5,7 +5,7 @@ from vllm.distributed import (
|
|
5
5
|
tensor_model_parallel_all_gather,
|
6
6
|
)
|
7
7
|
|
8
|
-
from sglang.srt.managers.
|
8
|
+
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
9
9
|
|
10
10
|
|
11
11
|
class LogitsProcessor(nn.Module):
|
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
|
|
50
50
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
51
51
|
pt = 0
|
52
52
|
# NOTE: the GPU-CPU overhead can be reduced
|
53
|
-
extend_seq_lens_cpu = input_metadata.extend_seq_lens.
|
54
|
-
for i in
|
55
|
-
if
|
53
|
+
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
54
|
+
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
55
|
+
if extend_seq_len == 0:
|
56
56
|
prefill_top_logprobs.append([])
|
57
57
|
decode_top_logprobs.append([])
|
58
58
|
continue
|
59
59
|
k = input_metadata.top_logprobs_nums[i]
|
60
|
-
t = all_logprobs[pt : pt +
|
60
|
+
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
61
61
|
vs_cpu = t.values.tolist()
|
62
62
|
ps_cpu = t.indices.tolist()
|
63
63
|
prefill_top_logprobs.append(
|
64
64
|
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
65
65
|
)
|
66
66
|
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
67
|
-
pt +=
|
67
|
+
pt += extend_seq_len
|
68
|
+
|
68
69
|
return prefill_top_logprobs, decode_top_logprobs
|
69
70
|
|
70
71
|
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
|
|
145
146
|
)
|
146
147
|
|
147
148
|
|
148
|
-
|
149
|
+
def test():
|
149
150
|
all_logprobs = torch.tensor(
|
150
151
|
# s s s
|
151
152
|
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
@@ -173,3 +174,7 @@ if __name__ == "__main__":
|
|
173
174
|
print("start", start)
|
174
175
|
print("end", end)
|
175
176
|
print("sum_logp", sum_logp)
|
177
|
+
|
178
|
+
|
179
|
+
if __name__ == "__main__":
|
180
|
+
test()
|
@@ -1,22 +1,26 @@
|
|
1
1
|
import torch
|
2
|
+
import numpy as np
|
2
3
|
from torch import nn
|
3
4
|
|
4
5
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
5
6
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
6
7
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
7
|
-
from sglang.srt.managers.
|
8
|
+
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
8
9
|
|
9
10
|
|
10
11
|
class RadixAttention(nn.Module):
|
11
|
-
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
|
12
|
+
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
|
12
13
|
super().__init__()
|
13
14
|
self.tp_q_head_num = num_heads
|
14
15
|
self.tp_k_head_num = num_kv_heads
|
15
16
|
self.tp_v_head_num = num_kv_heads
|
16
17
|
self.head_dim = head_dim
|
17
18
|
self.layer_id = layer_id
|
19
|
+
self.logit_cap = logit_cap
|
18
20
|
|
19
|
-
|
21
|
+
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
|
22
|
+
|
23
|
+
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
20
24
|
|
21
25
|
if global_server_args_dict.get("enable_flashinfer", False):
|
22
26
|
self.prefill_forward = self.prefill_forward_flashinfer
|
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
|
|
38
42
|
input_metadata.start_loc,
|
39
43
|
input_metadata.seq_lens,
|
40
44
|
input_metadata.max_seq_len,
|
45
|
+
self.logit_cap,
|
41
46
|
)
|
42
47
|
self.store_kv_cache(k, v, input_metadata)
|
43
48
|
|
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
|
|
62
67
|
input_metadata.extend_seq_lens,
|
63
68
|
input_metadata.max_seq_len,
|
64
69
|
input_metadata.max_extend_len,
|
70
|
+
self.logit_cap,
|
65
71
|
)
|
66
72
|
|
67
73
|
return o
|
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
|
|
82
88
|
input_metadata.max_seq_len,
|
83
89
|
input_metadata.other_kv_index,
|
84
90
|
input_metadata.total_num_tokens,
|
91
|
+
self.logit_cap,
|
85
92
|
)
|
86
93
|
|
87
94
|
return o
|