sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
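The two new attributes give each expert-parallel (EP) rank an inclusive range of global expert IDs; the W4AFp8 MoE method further below passes layer.start_expert_id into cutlass_w4a8_moe. A standalone sketch of the arithmetic, using a hypothetical helper name:

```python
# Hypothetical helper (not sglang code) showing the arithmetic behind the two new fields.
def local_expert_range(num_experts: int, moe_ep_size: int, moe_ep_rank: int):
    assert num_experts % moe_ep_size == 0
    num_local_experts = num_experts // moe_ep_size
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1  # inclusive upper bound
    return start_expert_id, end_expert_id

# With 8 experts over 4 EP ranks, rank 2 owns global experts 4..5.
print(local_expert_range(num_experts=8, moe_ep_size=4, moe_ep_rank=2))  # (4, 5)
```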
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):

         if (
             "compressed" in self.quant_method.__class__.__name__.lower()
-
-            and (param.data[expert_id]
+            or "w4afp8" in self.quant_config.get_name()
+            and (param.data[expert_id] != 1).any()
+            and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
         ):
             raise ValueError(
                 "input_scales of w1 and w3 of a layer "
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py
ADDED
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
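The wrapper above only sizes the output buffers and delegates the actual sort/pad to the sgl_kernel op. For intuition, here is a rough pure-PyTorch reference of the alignment described in the docstring (illustrative only; it does not reproduce the kernel's EP handling of -1 ids or its worst-case buffer sizing):

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # sentinel token index used for padding slots
    sorted_token_ids, expert_ids = [], []
    for e in range(num_experts):
        tokens = (flat == e).nonzero(as_tuple=True)[0].tolist()
        pad = (-len(tokens)) % block_size  # pad each expert's token list to a block boundary
        tokens += [pad_id] * pad
        sorted_token_ids += tokens
        expert_ids += [e] * (len(tokens) // block_size)
    return (
        torch.tensor(sorted_token_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([len(sorted_token_ids)], dtype=torch.int32),
    )

# The docstring example uses expert ids 1..4, so pass num_experts=5 to cover id 4.
topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
sorted_ids, expert_ids, num_post_pad = moe_align_block_size_ref(topk_ids, 4, 5)
print(sorted_ids.tolist())  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
print(expert_ids.tolist())  # [1, 2, 3, 4]
print(num_post_pad.item())  # 16
```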
sglang/srt/layers/moe/topk.py
CHANGED
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256
+        if global_num_experts == 256:

             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)

-
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
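The new renormalization divides the NPU top-k weights by their sum; when shared experts are fused into the output (num_fused_shared_experts != 0), the last column is excluded from the denominator so only the routed experts sum to one. A small standalone illustration with made-up weights:

```python
import torch

topk_weights = torch.tensor([[2.0, 1.0, 1.0, 4.0]])  # last column: fused shared expert
num_fused_shared_experts = 1

denom = (
    topk_weights.sum(dim=-1, keepdim=True)
    if num_fused_shared_experts == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
print(topk_weights / denom)  # tensor([[0.5000, 0.2500, 0.2500, 1.0000]]); routed columns sum to 1
```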
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-
-
-
-
-
-
-
-
-
-
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
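The correction_bias branch implements bias-corrected routing: experts are selected by the biased scores, but the returned weights come from the unbiased softmax. A minimal standalone sketch of that selection path (arbitrary shapes):

```python
import torch

gating_output = torch.randn(4, 8)   # [num_tokens, num_experts]
correction_bias = torch.randn(8)    # per-expert bias, used only for expert selection
topk = 2

scores = gating_output.softmax(dim=-1)
scores_for_choice = scores + correction_bias.unsqueeze(0)
topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)  # weights taken from the unbiased scores
print(topk_ids.shape, topk_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```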
@@ -370,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may
+            "It may take a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
         kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
     )

+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
     # TODO can use multi thread
     for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
         executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()


 class _BaseWarmupExecutor:
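The hunk saves the current DeepGEMM compile mode, switches to mode 1 for the warmup loop, restores it, and then frees the warmup buffers. If the restore should also happen when warmup raises, the same save/restore can be wrapped in a context manager; a generic, runnable sketch with a stand-in object (only get_compile_mode/set_compile_mode, the two calls shown above, are assumed):

```python
from contextlib import contextmanager

class _FakeDeepGemm:
    """Stand-in exposing only the two calls used in the hunk above."""
    def __init__(self):
        self._mode = 0
    def get_compile_mode(self):
        return self._mode
    def set_compile_mode(self, mode):
        self._mode = mode

@contextmanager
def compile_mode(dg, mode):
    old = dg.get_compile_mode()
    dg.set_compile_mode(mode)
    try:
        yield
    finally:
        dg.set_compile_mode(old)  # restored even if warmup raises

dg = _FakeDeepGemm()
with compile_mode(dg, 1):
    pass  # warmup loop would run here
assert dg.get_compile_mode() == 0
```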
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             if re.fullmatch(regex_str, prefix):
                 return True
+
+            # Check if the last part of the excluded pattern is contained in the last part of the prefix
+            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+            pattern_last_part = pattern.split(".")[-1]
+            prefix_last_part = prefix.split(".")[-1]
+            if pattern_last_part in prefix_last_part:
+                return True
         return False

     def get_quant_method(
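The added fallback compares only the last dotted component of the excluded pattern against the last component of the module prefix, so fused modules whose names merely contain an excluded sub-module name are also excluded. A standalone illustration of the matching rule (hypothetical layer names):

```python
def last_part_excluded(pattern: str, prefix: str) -> bool:
    # Mirrors the added fallback: substring match on the final dotted component.
    return pattern.split(".")[-1] in prefix.split(".")[-1]

print(last_part_excluded("model.layers.*.kv_a_proj_with_mqa",
                         "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"))  # True
print(last_part_excluded("model.layers.*.o_proj",
                         "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"))  # False
```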
sglang/srt/layers/quantization/mxfp4.py
CHANGED
@@ -66,10 +66,15 @@ _is_hip = is_hip()

 if _is_hip:
     # import aiter
-
-
-
-
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err


 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
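Binding the missing names to the ImportError keeps ROCm builds without aiter importable; the error only surfaces if one of the symbols is actually used. A generic, runnable illustration of the same deferral pattern (hypothetical module and function names, not the aiter imports themselves):

```python
try:
    from some_optional_backend import fused_op  # hypothetical optional dependency
except ImportError as err:
    fused_op = err  # defer the failure until the symbol is actually needed

def run_fused(x):
    if isinstance(fused_op, ImportError):
        raise RuntimeError("optional backend is not installed") from fused_op
    return fused_op(x)

try:
    run_fused(None)
except RuntimeError as e:
    print(e)  # optional backend is not installed
```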
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -77,6 +77,19 @@ def is_layer_skipped(
         )
     else:
         is_skipped = prefix in ignored_layers
+        if "gate_up_proj" in prefix:
+            prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
+            prefix_up = prefix.replace("gate_up_proj", "up_proj")
+            if prefix_gate in ignored_layers and prefix_up in ignored_layers:
+                is_skipped = True
+        elif "experts" in prefix:
+            is_skipped = any(
+                [
+                    prefix in layer_name
+                    for layer_name in ignored_layers
+                    if "experts" in layer_name
+                ]
+            )

     assert is_skipped is not None
     return is_skipped
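Ignore lists are usually written against the unfused checkpoint names, so a fused gate_up_proj is only treated as skipped when both of its constituents (gate_proj and up_proj) are ignored. A standalone sketch of that lookup (hypothetical layer names, not the sglang function itself):

```python
def fused_gate_up_skipped(prefix: str, ignored_layers: set) -> bool:
    # A fused gate_up_proj is skipped only if both unfused halves are ignored.
    if "gate_up_proj" not in prefix:
        return prefix in ignored_layers
    prefix_gate = prefix.replace("gate_up_proj", "gate_proj")
    prefix_up = prefix.replace("gate_up_proj", "up_proj")
    return prefix in ignored_layers or (
        prefix_gate in ignored_layers and prefix_up in ignored_layers
    )

ignored = {"model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"}
print(fused_gate_up_skipped("model.layers.0.mlp.gate_up_proj", ignored))  # True
print(fused_gate_up_skipped("model.layers.1.mlp.gate_up_proj", ignored))  # False
```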
sglang/srt/layers/quantization/w4afp8.py
CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.managers.schedule_batch import global_server_args_dict

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer,
+        elif isinstance(layer, FusedMoE):
             return W4AFp8MoEMethod(self)
         return None
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
         return []


-
+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+

+class W4AFp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config

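interleave_scales now lives at module level (it is reused below in process_weights_after_loading) and falls back to alignment 1 when the last dimension is not divisible by 4. A quick shape walk-through of the reshape/permute/reshape it performs, with arbitrary sizes:

```python
import torch

# alignment = 4 since the last dim (8) is divisible by 4.
scales = torch.arange(2 * 3 * 8, dtype=torch.float32).reshape(2, 3, 8)  # [num_experts, d1, d2]
x = scales.reshape(2, 3, 8 // 4, 4)   # split d2 into groups of 4
x = x.permute(0, 2, 1, 3)             # move the group index next to the expert dim
x = x.reshape(2, 8 // 4, 3 * 4)       # fold each group of 4 into d1
print(x.shape)                        # torch.Size([2, 2, 12])
```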
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         return

-    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-        s_shape = scales.shape
-        # Reshape to separate groups of 4
-        scales_interleaved = scales.reshape(
-            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-        )
-        # Permute dimensions to interleave
-        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-        # Reshape back to original dimensions but with interleaved values
-        scales_interleaved = scales_interleaved.reshape(
-            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-        )
-        return scales_interleaved.contiguous()
-
     def process_weights_after_loading(self, layer: Module) -> None:
         dtype = torch.bfloat16
         device = layer.w2_weight.device

         # Interleave w13_weight_scale (gate_up_proj)
         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-        w13_weight_scale =
+        w13_weight_scale = interleave_scales(w13_weight_scale)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

         # Interleave w2_weight_scale (down_proj)
         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-        w2_weight_scale =
+        w2_weight_scale = interleave_scales(w2_weight_scale)
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

         # Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-
-
-
-
-
+        if get_moe_expert_parallel_world_size() > 1:
+            local_topk_ids = torch.where(
+                topk_ids == -1,
+                layer.num_experts,
+                topk_ids,
+            )

         output = cutlass_w4a8_moe(
             layer.start_expert_id,
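Under expert parallelism, top-k ids for experts owned by other ranks arrive as -1; the added torch.where remaps them to the sentinel value layer.num_experts before calling cutlass_w4a8_moe. A tiny standalone illustration of the remap (made-up ids):

```python
import torch

num_experts = 4  # sentinel id = one past the last valid expert index
topk_ids = torch.tensor([[0, 3], [-1, 2], [-1, -1]])
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
print(local_topk_ids)  # tensor([[0, 3], [4, 2], [4, 4]])
```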
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict

     @staticmethod
@@ -582,11 +582,11 @@
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed


+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
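apply_rotary_pos_emb_npu transposes into the layout expected by torch_npu.npu_apply_rotary_pos_emb and falls back to the native path for head dims other than 128. For reference, the rotate-half formulation that native path applies looks roughly like this (a generic sketch assuming a [batch, heads, seq, head_dim] layout; the exact signature in rotary_embedding.py may differ):

```python
import torch
from typing import Tuple

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_native(q, k, cos, sin, unsqueeze_dim=1) -> Tuple[torch.Tensor, torch.Tensor]:
    cos = cos.unsqueeze(unsqueeze_dim)  # broadcast over the head dimension
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = torch.randn(1, 8, 16, 64)   # [batch, heads, seq, head_dim]
k = torch.randn(1, 8, 16, 64)
cos = torch.randn(1, 16, 64)
sin = torch.randn(1, 16, 64)
q_out, k_out = apply_rope_native(q, k, cos, sin)
print(q_out.shape, k_out.shape)  # torch.Size([1, 8, 16, 64]) torch.Size([1, 8, 16, 64])
```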
sglang/srt/layers/sampler.py
CHANGED
@@ -27,6 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)

 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


 class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
         else:
+            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+                logprobs = torch.softmax(logits, dim=-1)
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):

             if return_logprob:
                 # clamp to avoid -inf
-
+                if RETURN_ORIGINAL_LOGPROB:
+                    logprobs = torch.log(logprobs).clamp(
+                        min=torch.finfo(logprobs.dtype).min
+                    )
+                else:
+                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
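With RETURN_ORIGINAL_LOGPROB set, the reported log-probabilities come from the unscaled distribution, while sampling still uses the temperature-scaled one; otherwise the log-probs reflect the scaled distribution. A small comparison with made-up values:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])
temperature = 0.5

original_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)         # RETURN_ORIGINAL_LOGPROB path
probs = torch.softmax(logits / temperature, dim=-1)                         # distribution used for sampling
scaled_logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)  # default path

print(original_logprobs)  # tensor([[-0.4076, -1.4076, -2.4076]])
print(scaled_logprobs)    # sharper: tensor([[-0.1429, -2.1429, -4.1429]])
```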
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


-def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+def get_top_logprobs(
+    logprobs: torch.Tensor,
+    top_logprobs_nums: List[int],
+):
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     for i, k in enumerate(top_logprobs_nums):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
-
+
+    return (
+        output_top_logprobs_val,
+        output_top_logprobs_idx,
+    )


-def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+def get_token_ids_logprobs(
+    logprobs: torch.Tensor,
+    token_ids_logprobs: List[List[int]],
+):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
         output_token_ids_logprobs_val.append([])
         output_token_ids_logprobs_idx.append([])

-    return
+    return (
+        output_token_ids_logprobs_val,
+        output_token_ids_logprobs_idx,
+    )


 def apply_custom_logit_processor(
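Both helpers now return their two result lists as an explicit tuple. Functionally they do one batched topk (or gather) at the maximum requested k and then truncate per request; a standalone sketch of that pattern for the top-logprobs case (not the sglang function itself):

```python
import torch
from typing import List, Tuple

def top_logprobs_ref(
    logprobs: torch.Tensor, top_logprobs_nums: List[int]
) -> Tuple[List[List[float]], List[List[int]]]:
    # One batched topk at max_k, then per-request truncation to k = top_logprobs_nums[i].
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values, indices = ret.values.tolist(), ret.indices.tolist()
    vals, idxs = [], []
    for i, k in enumerate(top_logprobs_nums):
        vals.append(values[i][:k])
        idxs.append(indices[i][:k])
    return vals, idxs

logprobs = torch.nn.functional.log_softmax(torch.randn(2, 5), dim=-1)
vals, idxs = top_logprobs_ref(logprobs, [2, 3])
print([len(v) for v in vals])  # [2, 3]
```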