sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
@@ -497,7 +516,6 @@ def fused_moe_kernel(
 
             accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
         else:
-            # fix out of shared memory issue
             if use_fp8_w8a8:
                 accumulator = tl.dot(a, b, acc=accumulator)
             else:
@@ -568,7 +586,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
@@ -578,13 +596,9 @@ def moe_align_block_size(
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
    cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
    )
 
    # Threshold based on benchmark results
@@ -594,12 +608,11 @@ def moe_align_block_size(
 
    sgl_moe_align_block_size(
        topk_ids,
-        num_experts,
+        num_experts + 1,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
-        token_cnts_buffer,
        cumsum_buffer,
        fuse_sorted_ids_padding,
    )
@@ -1,17 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
+from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -28,6 +36,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)
 
 
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -62,8 +79,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -77,21 +95,19 @@ class FusedMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
-        skip_quant: Optional[bool] = False,
     ):
         super().__init__()
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
-        self.
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.expert_map_cpu = None
+        self.expert_map_gpu = None
 
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -99,28 +115,28 @@ class FusedMoE(torch.nn.Module):
             enable_ep_moe = False
 
         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.moe_tp_size = get_moe_tensor_parallel_world_size()
+        self.moe_tp_rank = get_moe_tensor_parallel_rank()
+        assert num_experts % self.moe_ep_size == 0
+        self.num_local_experts = num_experts // self.moe_ep_size
         if enable_ep_moe:
-
-            self.ep_rank = self.tp_rank
-            self.tp_size = 1
-            self.tp_rank = 0
+            # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.
+            self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
-
-
-
-            self.ep_rank
-            * self.num_local_experts : (self.ep_rank + 1)
+            self.expert_map_cpu[
+                self.moe_ep_rank
+                * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
-
-
-            self.num_local_experts = num_experts
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
         self.routed_scaling_factor = routed_scaling_factor
-        assert intermediate_size % self.
-        self.intermediate_size_per_partition = intermediate_size // self.
+        assert intermediate_size % self.moe_tp_size == 0
+        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -132,9 +148,6 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )
 
-        if skip_quant:
-            return
-
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -363,9 +376,9 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)
 
     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.
+        if self.expert_map_cpu is None:
             return expert_id
-        return self.
+        return self.expert_map_cpu[expert_id].item()
 
     def weight_loader(
         self,
@@ -375,10 +388,48 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
@@ -396,8 +447,7 @@ class FusedMoE(torch.nn.Module):
         expert_id: int,
     ) -> None:
 
-
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+        tp_rank = self.moe_tp_rank
 
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -417,7 +467,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -571,9 +621,14 @@ class FusedMoE(torch.nn.Module):
         )
         return
 
-    def forward(self, hidden_states: torch.Tensor, topk_output:
+    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None
 
+        if self.expert_map_gpu is not None:
+            topk_output = topk_output._replace(
+                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -584,17 +639,17 @@ class FusedMoE(torch.nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             **(
                 dict(
-                    tp_rank=self.
-                    tp_size=self.
-                    ep_rank=self.
-                    ep_size=self.
+                    tp_rank=self.moe_tp_rank,
+                    tp_size=self.moe_tp_size,
+                    ep_rank=self.moe_ep_rank,
+                    ep_size=self.moe_ep_size,
                 )
                 if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
                 else {}
             ),
         )
 
-        if self.reduce_results and (self.
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
@@ -627,3 +682,61 @@ class FusedMoE(torch.nn.Module):
             ("w3", ckpt_up_proj_name),
         ]
     ]
+
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -172,7 +173,6 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -181,8 +181,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, EPMoE):
-            return Fp8EPMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -493,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
@@ -984,23 +992,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        if isinstance(layer, EPMoE):
-            layer.w13_weight_scale = (
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            )
-            layer.w2_weight_scale = (
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            )
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
@@ -1094,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )
 
+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
 
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-        if isinstance(layer, EPMoE):
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         return self.forward(
             x=x,
             layer=layer,
@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
-        return []
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config: