sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/activation.py
@@ -0,0 +1,32 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for activation layers."""
+
+import torch
+import torch.nn.functional as F
+from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp
+
+
+class SiluAndMul(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
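Note: the new `SiluAndMul` op implements the standard gated activation used by LLaMA-style MLPs: split the last dimension in half, apply SiLU to the first half, and multiply by the second half. A minimal pure-PyTorch sketch of the `forward_native` path (no flashinfer required; shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Reference behavior of SiluAndMul.forward_native: SiLU(first half) * second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 2 * 11008)   # e.g. [num_tokens, 2 * intermediate_size]
y = silu_and_mul_ref(x)         # -> [4, 11008]
```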
--- a/sglang/srt/layers/token_attention.py
+++ b/sglang/srt/layers/decode_attention.py
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)
 
 
-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )
 
 
-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )
 
 
-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
 
-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
@@ -324,7 +328,7 @@ def token_attention_fwd(
         sm_scale,
         logit_cap,
     )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
        att_m,
        v_buffer,
        o,
--- a/sglang/srt/layers/extend_attention.py
+++ b/sglang/srt/layers/extend_attention.py
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
@@ -270,7 +275,9 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = Lv
 
-    if CUDA_CAPABILITY[0] >= 8:
+    if CUDA_CAPABILITY[0] >= 9:
+        BLOCK_M, BLOCK_N = (128, 64)
+    elif CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
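Note: the new branch gives Hopper-class GPUs (compute capability 9.x) their own tile shape instead of falling through to the Ampere settings. A standalone sketch of the selection logic, assuming `Lq` is the query head dimension as in the kernel above (the `capability` argument is added here only so the snippet runs without a GPU):

```python
import torch

def pick_block_sizes(Lq: int, capability=None):
    major = (capability or torch.cuda.get_device_capability())[0]
    if major >= 9:                                    # Hopper (sm_90+)
        return 128, 64
    if major >= 8:                                    # Ampere (sm_80 / sm_86)
        return (128, 128) if Lq <= 128 else (64, 64)
    return (64, 64) if Lq <= 128 else (32, 32)        # older architectures

print(pick_block_sizes(Lq=128, capability=(9, 0)))    # -> (128, 64)
```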
--- /dev/null
+++ b/sglang/srt/layers/fused_moe/__init__.py
@@ -0,0 +1 @@
+from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
--- a/sglang/srt/layers/fused_moe.py
+++ b/sglang/srt/layers/fused_moe/fused_moe.py
@@ -1,20 +1,5 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
 # Adapted from
-# https://github.com/vllm-project/vllm/
+# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
 """Fused MoE kernel."""
 import functools
 import json
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
@@ -373,6 +359,31 @@ def get_default_config(
     return config
 
 
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
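Note: `try_get_optimal_moe_config` centralizes the config lookup that `fused_experts` previously did inline. Tuned configs are keyed by token count `M`, and the entry whose key is closest to the actual `M` wins. A small illustration of that nearest-key lookup with a made-up (hypothetical) config table:

```python
# Hypothetical tuned-config table keyed by M; values are Triton tile parameters.
configs = {
    1:    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64},
    16:   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 128},
    256:  {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128},
    4096: {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256},
}

M = 300
# Same idiom as in try_get_optimal_moe_config: pick the key closest to M.
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)   # -> the entry tuned for M=256
```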
@@ -403,6 +414,41 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
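Note: `grouped_topk` implements DeepSeek-V2-style routing: experts are partitioned into `num_expert_group` groups, only the best `topk_group` groups (judged by their highest expert score) stay eligible, and the final top-k experts are then picked from those groups. A self-contained toy run of the same tensor operations:

```python
import torch

num_tokens, num_experts, num_expert_group, topk_group, topk = 2, 8, 4, 2, 2
scores = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# Best score per group, then keep only the top `topk_group` groups.
group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

# Expand the group mask back to per-expert granularity and zero out the rest.
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
    .reshape(num_tokens, -1)
)
masked = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1, sorted=False)
print(topk_ids)   # chosen experts all come from the surviving groups
```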
@@ -425,24 +471,23 @@ def fused_experts(
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
+    config = get_config_func(M)
 
     intermediate_cache1 = torch.empty(
         (M, topk_ids.shape[1], N),
@@ -460,56 +505,85 @@ def fused_experts(
         dtype=hidden_states.dtype,
     )
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (
+            chunk * CHUNK_SIZE,
+            min((chunk + 1) * CHUNK_SIZE, num_tokens),
+        )
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            # Adjust the intermediate cache size and config for the last
+            # chunk. Note that in most cases we only have one chunk
+            # so the cache size and config are already set correctly and
+            # do not need to be adjusted.
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            config = get_config_func(tokens_in_chunk)
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            curr_topk_ids, config["BLOCK_SIZE_M"], E
+        )
 
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+        invoke_fused_moe_kernel(
+            curr_hidden_states,
+            w1,
+            intermediate_cache1,
+            a1_scale,
+            w1_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            topk_ids.shape[1],
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )
 
-    if inplace:
-        return torch.sum(
+        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(
+            intermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2_scale,
+            w2_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            True,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )
+
+        torch.sum(
             intermediate_cache3.view(*intermediate_cache3.shape),
             dim=1,
-            out=hidden_states,
+            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
         )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+    return out_hidden_states
 
 
 def fused_moe(
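Note: the rewritten `fused_experts` now runs the kernels over the token dimension in chunks of at most `VLLM_FUSED_MOE_CHUNK_SIZE` tokens (to limit intermediate-cache memory, per the linked vllm-project/vllm#5938), writing each chunk's result into `out_hidden_states`. The chunk bookkeeping reduces to this pattern (a sketch with small illustrative values; the real chunk size comes from `vllm.envs`):

```python
num_tokens, CHUNK_SIZE = 10, 4   # illustrative values only

for chunk in range((num_tokens // CHUNK_SIZE) + 1):
    begin, end = chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)
    tokens_in_chunk = end - begin
    if tokens_in_chunk == 0:     # happens when num_tokens divides evenly
        break
    print(f"chunk {chunk}: tokens [{begin}, {end})")
# chunk 0: tokens [0, 4)
# chunk 1: tokens [4, 8)
# chunk 2: tokens [8, 10)
```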
@@ -521,6 +595,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -543,6 +620,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -556,12 +637,18 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    if ...:
-        topk_weights, topk_ids = ...(
-            hidden_states, gating_output, topk, renormalize
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
         )
     else:
-        topk_weights, topk_ids = ...(
+        topk_weights, topk_ids = fused_topk(
             hidden_states, gating_output, topk, renormalize
         )
 
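Note: callers opt into the DeepSeek-V2 routing path through the new keyword arguments. A hedged usage sketch, assuming the leading positional parameters of `fused_moe` are `hidden_states, w1, w2, gating_output, topk` as in the vLLM kernel this file is adapted from (the call itself needs a CUDA device, Triton, and real expert weights):

```python
out = fused_moe(
    hidden_states,            # [num_tokens, hidden_size]
    w1,                       # [num_experts, 2 * intermediate_size, hidden_size]
    w2,                       # [num_experts, hidden_size, intermediate_size]
    gating_output,            # [num_tokens, num_experts]
    topk=6,
    renormalize=True,
    use_grouped_topk=True,    # DeepSeek-V2 style grouped routing
    num_expert_group=8,       # both must be provided when use_grouped_topk=True
    topk_group=3,
)
```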
@@ -579,33 +666,3 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
     )
-
-
-def fused_topk_v0_4_3(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    import vllm._moe_C as moe_kernels
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-    moe_kernels.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids