sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/fused_moe.py
CHANGED
@@ -9,10 +9,8 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
-
from vllm import _custom_ops as ops
from vllm.logger import init_logger
-from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -109,12 +107,16 @@ def fused_moe_kernel(

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
@@ -130,13 +132,12 @@ def fused_moe_kernel(
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +148,7 @@ def fused_moe_kernel(
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
@@ -159,15 +158,14 @@ def fused_moe_kernel(
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.
@@ -206,32 +204,38 @@ def moe_align_block_size(
    by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
    return sorted_ids, expert_ids, num_tokens_post_pad


+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

@@ -242,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )

    fused_moe_kernel[grid](
        A,
@@ -281,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:


@functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

@@ -297,11 +302,11 @@ def get_moe_configs(E: int, N: int,
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

@@ -310,6 +315,188 @@ def get_moe_configs(E: int, N: int,
    return None


+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    if dtype == "float8":
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
+    return config
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
+    # Check constraints.
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    if inplace:
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+
+
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
@@ -352,134 +539,58 @@ def fused_moe(
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape

-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
+    if hasattr(ops, "topk_softmax"):
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    if override_config:
-        config = override_config
    else:
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )

+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )

-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)

+def fused_topk_v0_4_3(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    import vllm._moe_C as moe_kernels

+    M, _ = hidden_states.shape

+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+    return topk_weights, topk_ids