sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
|
|
175
175
|
self.moe_tp_rank = get_moe_tensor_parallel_rank()
|
176
176
|
assert num_experts % self.moe_ep_size == 0
|
177
177
|
self.num_local_experts = num_experts // self.moe_ep_size
|
178
|
+
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
179
|
+
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
178
180
|
if self.moe_ep_size > 1:
|
179
181
|
# TODO(ch-wan): support shared experts fusion
|
180
182
|
# Create a tensor of size num_experts filled with -1
|
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):
|
|
593
595
|
|
594
596
|
if (
|
595
597
|
"compressed" in self.quant_method.__class__.__name__.lower()
|
596
|
-
|
597
|
-
and (param.data[expert_id]
|
598
|
+
or "w4afp8" in self.quant_config.get_name()
|
599
|
+
and (param.data[expert_id] != 1).any()
|
600
|
+
and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
|
598
601
|
):
|
599
602
|
raise ValueError(
|
600
603
|
"input_scales of w1 and w3 of a layer "
|
@@ -0,0 +1,87 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Tuple
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import triton
|
7
|
+
|
8
|
+
from sglang.srt.utils import is_cuda, is_hip
|
9
|
+
|
10
|
+
_is_cuda = is_cuda()
|
11
|
+
_is_hip = is_hip()
|
12
|
+
|
13
|
+
if _is_cuda or _is_hip:
|
14
|
+
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
15
|
+
|
16
|
+
|
17
|
+
def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group routed (token, k) slots by expert and pad every expert's group up to
    a multiple of ``block_size`` so the grouped GEMM only sees whole blocks.

    Parameters:
    - topk_ids: Tensor of shape [total_tokens, top_k]; entry (t, k) is the
      expert chosen for the k-th routing slot of token t.
    - block_size: Row-block size used by the block matrix multiplication.
    - num_experts: Total number of experts.

    Returns:
    - sorted_token_ids: Flattened slot indices reordered by expert. Padding
      entries hold the sentinel value ``topk_ids.numel()``.
    - expert_ids: Expert index that owns each block of ``sorted_token_ids``.
    - num_tokens_post_padded: One-element tensor with the padded slot count,
      which is a multiple of ``block_size``.

    Example (block_size = 4, num_experts = 4):
        topk_ids = [[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]]
        flattened -> [1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2]   (12 slots)
        Each expert receives 3 slots, so one sentinel (12) is appended per
        expert, and sorting by expert gives
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Entries equal to 12 do not correspond to real slots and are skipped by
        the subsequent matrix multiplication.
    """
    device = topk_ids.device
    num_slots = topk_ids.numel()

    # Upper bound on the padded length: each of the num_experts + 1 id buckets
    # may need up to block_size - 1 padding slots.
    padded_capacity = num_slots + (num_experts + 1) * (block_size - 1)
    sorted_ids = torch.empty((padded_capacity,), dtype=torch.int32, device=device)
    num_blocks = -(-padded_capacity // block_size)  # ceiling division
    expert_ids = torch.empty((num_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    # Under expert parallelism, filtered-out experts show up as -1, so the
    # kernel is given num_experts + 1 id buckets (and one extra cumsum slot).
    cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device=device)

    # Benchmark-derived cutoff. Above it the sentinel pre-fill is done here;
    # at or below it the kernel is asked (via the flag) to handle the padding.
    fuse_sorted_ids_padding = padded_capacity <= 4096
    if not fuse_sorted_ids_padding:
        sorted_ids.fill_(num_slots)

    sgl_moe_align_block_size(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        cumsum_buffer,
        fuse_sorted_ids_padding,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
|
sglang/srt/layers/moe/utils.py
CHANGED
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
|
|
132
132
|
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
|
133
133
|
)
|
134
134
|
|
135
|
+
old_compile_mode = deep_gemm.get_compile_mode()
|
136
|
+
deep_gemm.set_compile_mode(1)
|
135
137
|
# TODO can use multi thread
|
136
138
|
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
|
137
139
|
executor.execute(m=m)
|
140
|
+
deep_gemm.set_compile_mode(old_compile_mode)
|
141
|
+
|
142
|
+
# clean up input buffers
|
143
|
+
torch.cuda.current_stream().synchronize()
|
144
|
+
del executor
|
145
|
+
torch.cuda.empty_cache()
|
138
146
|
|
139
147
|
|
140
148
|
class _BaseWarmupExecutor:
|
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
|
|
517
517
|
def get_config_filenames(cls) -> List[str]:
|
518
518
|
return ["hf_quant_config.json"]
|
519
519
|
|
520
|
+
@staticmethod
def common_group_size(cfg: dict) -> int:
    """Return the single ``group_size`` declared anywhere in ``cfg``.

    Dicts searched: ``cfg`` itself, ``cfg["quantization"]``, every value of
    ``cfg["config_groups"]``, and every dict nested directly inside such a
    value (e.g. ``weights`` / ``input_activations``). Only ``int`` values are
    counted; non-dict entries are ignored.

    Raises:
        ValueError: if no group size is found, or if several distinct values
            are found.
    """
    scopes = [cfg]

    quant_block = cfg.get("quantization")
    if isinstance(quant_block, dict):
        scopes.append(quant_block)

    for group in (cfg.get("config_groups") or {}).values():
        if not isinstance(group, dict):
            continue
        scopes.append(group)
        scopes.extend(child for child in group.values() if isinstance(child, dict))

    found = set()
    for scope in scopes:
        candidate = scope.get("group_size")
        if isinstance(candidate, int):
            found.add(candidate)

    if not found:
        raise ValueError("No group_size found in config.")
    if len(found) > 1:
        raise ValueError(f"Inconsistent group_size values: {sorted(found)}")
    (unique_size,) = found
    return unique_size
|
552
|
+
|
520
553
|
@classmethod
|
521
554
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
522
555
|
# Handle two different config formats:
|
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|
549
582
|
else:
|
550
583
|
kv_cache_quant_algo = "auto"
|
551
584
|
|
552
|
-
group_size =
|
585
|
+
group_size = ModelOptFp4Config.common_group_size(config)
|
553
586
|
exclude_modules = config.get("ignore", [])
|
554
587
|
else:
|
555
588
|
# Fall back to nested format (hf_quant_config.json - legacy format)
|
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|
559
592
|
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
|
560
593
|
if not kv_cache_quant_algo:
|
561
594
|
kv_cache_quant_algo = "auto"
|
562
|
-
group_size =
|
595
|
+
group_size = ModelOptFp4Config.common_group_size(config)
|
563
596
|
exclude_modules = quant_config.get("exclude_modules", [])
|
564
597
|
except (ValueError, KeyError):
|
565
598
|
raise ValueError(
|
@@ -816,7 +816,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
|
816
816
|
moe_runner_config: MoeRunnerConfig,
|
817
817
|
) -> torch.Tensor:
|
818
818
|
topk_weights, topk_ids, _ = topk_output
|
819
|
-
|
819
|
+
if _is_hip:
|
820
|
+
topk_weights = topk_weights.to(
|
821
|
+
torch.float32
|
822
|
+
) # aiter's moe_sorting requires topk_weights to be FP32
|
820
823
|
return fused_moe(
|
821
824
|
x,
|
822
825
|
layer.w13_weight,
|
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
8
8
|
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
9
9
|
from aiter.ops.shuffle import shuffle_weight
|
10
10
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
11
|
+
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
11
12
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
12
13
|
from aiter.utility import dtypes
|
13
14
|
from aiter.utility.fp4_utils import e8m0_shuffle
|
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|
38
39
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
39
40
|
return
|
40
41
|
|
41
|
-
# for aiter implement
|
42
|
-
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
|
43
|
-
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
|
44
|
-
|
45
|
-
# layer.weight = torch.nn.Parameter(wshuffle,
|
46
|
-
# requires_grad=False)
|
47
|
-
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
|
48
|
-
# requires_grad=False)
|
49
|
-
|
50
42
|
def create_weights(
|
51
43
|
self,
|
52
44
|
layer: torch.nn.Module,
|
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|
93
85
|
x: torch.Tensor,
|
94
86
|
bias: Optional[torch.Tensor] = None,
|
95
87
|
) -> torch.Tensor:
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
88
|
+
# This path does not have support for bias currently
|
89
|
+
assert bias is None, "bias is not supported"
|
90
|
+
|
91
|
+
three_d = False
|
92
|
+
x_s = None
|
93
|
+
y = None
|
94
|
+
if isinstance(x, tuple):
|
95
|
+
assert len(x) in [
|
96
|
+
2,
|
97
|
+
3,
|
98
|
+
], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
|
99
|
+
if len(x) == 2:
|
100
|
+
x, x_s = x
|
101
|
+
elif len(x) == 3:
|
102
|
+
x, x_s, y = x
|
103
|
+
|
104
|
+
use_fused_quant_gemm = (
|
105
|
+
x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
|
114
106
|
)
|
115
107
|
|
116
|
-
|
117
|
-
|
118
|
-
|
108
|
+
if x.dim() == 3:
|
109
|
+
three_d = True
|
110
|
+
x = x.view(-1, x.shape[-1])
|
111
|
+
output_shape = [*x.shape[:-1], layer.weight.shape[0]]
|
112
|
+
|
113
|
+
# use_fused_quant_gemm = true, x_q is a bf16/fp16 num
|
114
|
+
# x_s is not None = true, x_q is uint8 num
|
115
|
+
if use_fused_quant_gemm or x_s is not None:
|
116
|
+
x_q = x
|
117
|
+
else:
|
118
|
+
x_q, x_s = dynamic_mxfp4_quant(x)
|
119
|
+
|
120
|
+
if y is None:
|
121
|
+
y = torch.empty(
|
122
|
+
x_q.shape[0],
|
123
|
+
layer.weight.shape[0],
|
124
|
+
device=x_q.device,
|
125
|
+
dtype=self.out_dtype,
|
126
|
+
)
|
127
|
+
|
128
|
+
if use_fused_quant_gemm:
|
129
|
+
gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
|
130
|
+
y = y.to(x.dtype)
|
131
|
+
else:
|
132
|
+
gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
|
133
|
+
|
134
|
+
if three_d:
|
135
|
+
return y.view(*output_shape)
|
136
|
+
|
137
|
+
return y
|
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
|
|
5
5
|
from types import MappingProxyType
|
6
6
|
from typing import Any, Optional
|
7
7
|
|
8
|
+
import torch
|
9
|
+
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
10
|
+
from torch import nn
|
11
|
+
|
8
12
|
|
9
13
|
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
10
14
|
if type(dict1) is not type(dict2):
|
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
|
|
105
109
|
elif target == value:
|
106
110
|
return True
|
107
111
|
return False
|
112
|
+
|
113
|
+
|
114
|
+
# utility for tensor dims > 2 cases
|
115
|
+
def b_dynamic_mxfp4_quant(x):
    """Run ``dynamic_mxfp4_quant`` on a tensor with leading batch dimensions.

    The last dimension ``d`` is quantized. All leading dimensions are
    flattened into rows for the 2-D quant kernel and restored afterwards, so
    any rank >= 1 is accepted (the 3-D case behaves exactly as before).

    Returns:
        ``(x_q, x_scales)`` viewed as ``(*lead, d // 2)`` (two fp4 codes per
        byte) and ``(*lead, d // 32)`` (one scale per 32 values). Assumes the
        kernel returns exactly those element counts; TODO confirm for
        ``d % 32 != 0``.
    """
    *lead, d = x.shape
    x_q, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
    return x_q.view(*lead, d // 2), x_scales.view(*lead, d // 32)
|
119
|
+
|
120
|
+
|
121
|
+
def mxfp4_to_f32(x, is_threed):
    """Decode packed MXFP4 (E2M1) codes into float32 values.

    Each byte of ``x`` holds two 4-bit codes. The low nibble is the first
    value and the high nibble the second, so the output has the shape of
    ``x`` with its last dimension doubled. Block scales are NOT applied here.

    Args:
        x: Integer tensor of packed codes (uint8 expected).
        is_threed: Kept for backward compatibility. Nibbles are always
            unpacked along the last dimension. That is what the 3-D path did,
            and it is identical to the old 2-D path for 2-D input.

    Returns:
        float32 tensor on the same device as ``x``.
    """
    # 2 because we pack fp4 in uint8.
    codes = x.repeat_interleave(2, dim=-1)
    # Unpack along the last axis for any rank. The old `x[:, ::2]` branch
    # sliced dim 1, which is wrong for inputs that are not 2-D.
    codes[..., ::2] = codes[..., ::2] & 0xF
    codes[..., 1::2] = codes[..., 1::2] >> 4

    # E2M1 lookup table: codes 0-7 are the positive values, 8-15 their
    # negatives. Built on x's device instead of a hard-coded "cuda", so CPU
    # and non-default-GPU inputs work.
    e2m1_values = torch.tensor(
        [
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ],
        dtype=torch.float32,
        device=x.device,
    )
    return e2m1_values[codes.long()]
|
151
|
+
|
152
|
+
|
153
|
+
def e8m0_to_f32(x):
    """Decode raw E8M0 exponent codes into float32 scale factors.

    E8M0 has 8 exponent bits and no mantissa or sign bit, so code ``e``
    encodes ``2 ** (e - 127)``. Code 255 is the NaN encoding.

    Args:
        x: Tensor of raw exponent codes. Assumed to be the integer bit pattern
           (e.g. uint8), not an already-decoded float8 dtype; TODO confirm
           against callers.

    Returns:
        float32 tensor with the same shape as ``x``.
    """
    exponents = x.to(torch.float32)
    x_f32 = 2 ** (exponents - 127)
    # Mark NaN by testing the code itself. 2 ** (255 - 127) overflows float32
    # to inf. The previous `x_f32 == 128` test therefore never caught code 255
    # and instead turned the legitimate scale 2**7 (code 134) into NaN.
    x_f32[exponents == 255] = float("nan")
    return x_f32
|
165
|
+
|
166
|
+
|
167
|
+
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    """Split a ``kv_b_proj`` weight into MXFP4-quantized ``w_kc`` and ``w_vc``.

    ``w`` is cut along dim 0 into chunks of ``qk_nope_head_dim + v_head_dim``
    rows. Each chunk is split into its first ``qk_nope_head_dim`` rows
    (``kc``) and remaining ``v_head_dim`` rows (``vc``). Both parts are then
    re-quantized with ``b_dynamic_mxfp4_quant``; ``kc`` is quantized in its
    transposed layout and transposed back.

    Args:
        self_attn: Module providing ``qk_nope_head_dim`` and ``v_head_dim``,
            plus ``kv_b_proj.weight_scale`` when ``w`` is pre-quantized.
        w: bfloat16 weight, or uint8 packed MXFP4 codes. Packed codes are
            first dequantized to bfloat16, because the packed layout (two
            values per byte, one scale per 32 values) cannot be split and
            transposed directly.
        quant_format: Quantization format string; must contain ``"mxfp4"``.

    Returns:
        ``(w_kc, w_s_kc, w_vc, w_s_vc)``: quantized weights and their scales.

    Raises:
        NotImplementedError: for a non-MXFP4 ``quant_format`` or a ``w``
            dtype other than bfloat16/uint8. These cases previously failed
            with an UnboundLocalError at the return statement.
    """
    if "mxfp4" not in quant_format:
        raise NotImplementedError(
            f"quark_post_load_weights only supports mxfp4 formats, got {quant_format!r}"
        )

    if w.dtype == torch.uint8:  # static quant for mxfp4
        # Dequantize: unpack codes to fp32, then bf16. Apply the E8M0 block
        # scales, repeated 32x along the last dim to element granularity.
        w = mxfp4_to_f32(w, True).to(torch.bfloat16)
        w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
        w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
        w = w * w_scales
    elif w.dtype != torch.bfloat16:
        raise NotImplementedError(
            f"quark_post_load_weights expects bfloat16 or uint8 weights, got {w.dtype}"
        )

    # Both input kinds now share the same bf16 split + re-quantize path.
    chunk_rows = self_attn.qk_nope_head_dim + self_attn.v_head_dim
    w_kc, w_vc = w.unflatten(0, (-1, chunk_rows)).split(
        [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1
    )
    w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
    w_kc = w_kc.transpose(-2, -1)
    w_s_kc = w_s_kc.transpose(-2, -1)
    w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
    # transpose -> contiguous -> transpose changes only the memory layout
    # (strides) of w_s_kc, not its shape or values.
    w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
    w_s_vc = w_s_vc.contiguous().transpose(1, 2)

    return w_kc, w_s_kc, w_vc, w_s_vc
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
|
2
|
+
batched_gemm_afp4wfp4_pre_quant,
|
3
|
+
)
|
4
|
+
from aiter.ops.triton.fused_mxfp4_quant import (
|
5
|
+
fused_flatten_mxfp4_quant,
|
6
|
+
fused_rms_mxfp4_quant,
|
7
|
+
)
|
8
|
+
|
9
|
+
__all__ = [
|
10
|
+
"fused_rms_mxfp4_quant",
|
11
|
+
"fused_flatten_mxfp4_quant",
|
12
|
+
"batched_gemm_afp4wfp4_pre_quant",
|
13
|
+
]
|
@@ -1,12 +1,14 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
5
5
|
|
6
6
|
import torch
|
7
7
|
from torch.nn import Module
|
8
8
|
from torch.nn.parameter import Parameter
|
9
9
|
|
10
|
+
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
11
|
+
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
10
12
|
from sglang.srt.layers.quantization.base_config import (
|
11
13
|
FusedMoEMethodBase,
|
12
14
|
QuantizationConfig,
|
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
|
|
91
93
|
from sglang.srt.layers.linear import LinearBase
|
92
94
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
93
95
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
96
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
94
97
|
|
95
98
|
if isinstance(layer, LinearBase):
|
96
99
|
if is_layer_skipped(prefix, self.ignored_layers):
|
97
100
|
return UnquantizedLinearMethod()
|
98
101
|
return Fp8LinearMethod(self)
|
99
|
-
elif isinstance(layer,
|
102
|
+
elif isinstance(layer, FusedMoE):
|
100
103
|
return W4AFp8MoEMethod(self)
|
101
104
|
return None
|
102
105
|
|
@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
|
|
104
107
|
return []
|
105
108
|
|
106
109
|
|
107
|
-
|
110
|
+
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Interleave per-group scales in blocks of 4, like the TRT-LLM layout.

    Input shape ``(A, B, C)``. When ``C`` is divisible by 4:
    - the last axis is cut into ``C // 4`` groups of 4;
    - the result has shape ``(A, C // 4, B * 4)``;
    - output row ``g`` concatenates group ``g`` from each of the ``B`` input rows.

    When ``C`` is not a multiple of 4 the group width falls back to 1, which
    simply swaps the last two axes. The returned tensor is contiguous.
    """
    dim0, dim1, dim2 = scales.shape[0], scales.shape[1], scales.shape[2]
    width = 1 if dim2 % 4 else 4
    num_groups = dim2 // width

    grouped = scales.reshape(dim0, dim1, num_groups, width)
    # Move the group axis in front of the row axis, then fuse rows x width.
    regrouped = grouped.permute(0, 2, 1, 3)
    flattened = regrouped.reshape(dim0, num_groups, dim1 * width)
    return flattened.contiguous()
|
125
|
+
|
108
126
|
|
127
|
+
class W4AFp8MoEMethod(FusedMoEMethodBase):
|
109
128
|
def __init__(self, quant_config: W4AFp8Config):
|
110
129
|
self.quant_config = quant_config
|
111
130
|
|
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|
234
253
|
|
235
254
|
return
|
236
255
|
|
237
|
-
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
|
238
|
-
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
|
239
|
-
s_shape = scales.shape
|
240
|
-
# Reshape to separate groups of 4
|
241
|
-
scales_interleaved = scales.reshape(
|
242
|
-
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
|
243
|
-
)
|
244
|
-
# Permute dimensions to interleave
|
245
|
-
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
|
246
|
-
# Reshape back to original dimensions but with interleaved values
|
247
|
-
scales_interleaved = scales_interleaved.reshape(
|
248
|
-
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
|
249
|
-
)
|
250
|
-
return scales_interleaved.contiguous()
|
251
|
-
|
252
256
|
def process_weights_after_loading(self, layer: Module) -> None:
|
253
257
|
dtype = torch.bfloat16
|
254
258
|
device = layer.w2_weight.device
|
255
259
|
|
256
260
|
# Interleave w13_weight_scale (gate_up_proj)
|
257
261
|
w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
|
258
|
-
w13_weight_scale =
|
262
|
+
w13_weight_scale = interleave_scales(w13_weight_scale)
|
259
263
|
layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
|
260
264
|
|
261
265
|
# Interleave w2_weight_scale (down_proj)
|
262
266
|
w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
|
263
|
-
w2_weight_scale =
|
267
|
+
w2_weight_scale = interleave_scales(w2_weight_scale)
|
264
268
|
layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
|
265
269
|
|
266
270
|
# Process input scales
|
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|
291
295
|
|
292
296
|
topk_weights, topk_ids, _ = topk_output
|
293
297
|
local_topk_ids = topk_ids
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
298
|
+
if get_moe_expert_parallel_world_size() > 1:
|
299
|
+
local_topk_ids = torch.where(
|
300
|
+
topk_ids == -1,
|
301
|
+
layer.num_experts,
|
302
|
+
topk_ids,
|
303
|
+
)
|
299
304
|
|
300
305
|
output = cutlass_w4a8_moe(
|
301
306
|
layer.start_expert_id,
|
@@ -0,0 +1,44 @@
|
|
1
|
+
import torch
|
2
|
+
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
|
3
|
+
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
4
|
+
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
|
5
|
+
|
6
|
+
from sglang.srt.utils import BumpAllocator
|
7
|
+
|
8
|
+
__all__ = ["fused_qk_rope_cat"]
|
9
|
+
|
10
|
+
|
11
|
+
def aiter_dsv3_router_gemm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    gemm_output_zero_allocator: BumpAllocator = None,
):
    """Compute router logits of shape ``(M, N)`` with aiter GEMM kernels.

    ``M`` is ``hidden_states.shape[0]`` and ``N`` is ``weight.shape[0]``.

    - Small batches (``M <= 256``): an ``(M, N)`` output buffer is prepared.
      It comes from ``gemm_output_zero_allocator`` when one is given;
      otherwise it is a fresh float32 ``torch.zeros``. The buffer handed out
      by the allocator is presumably pre-zeroed; verify against BumpAllocator.
      ``gemm_a16w16_atomic`` writes into this buffer, and the result is cast
      to ``hidden_states.dtype``.
    - Larger batches: the plain ``gemm_a16w16`` kernel is used.

    Args:
        hidden_states: Router input activations; ``shape[0]`` is ``M``.
        weight: Router weight; ``shape[0]`` is ``N``.
        gemm_output_zero_allocator: Optional allocator used for the
            small-batch output buffer instead of ``torch.zeros``.

    Returns:
        The logits tensor.
    """
    M = hidden_states.shape[0]
    N = weight.shape[0]

    if M > 256:
        return gemm_a16w16(hidden_states, weight)

    # TODO (cagri): convert to bfloat16 as part of another kernel to save time
    # for now it is also coupled with zero allocator.
    if gemm_output_zero_allocator is not None:
        y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
    else:
        y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
    return gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
|
34
|
+
|
35
|
+
|
36
|
+
def get_dsv3_gemm_output_zero_allocator_size(
    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
):
    """Return the zero-buffer size to reserve for the DSv3 router GEMM.

    Only the configuration this helper targets gets a buffer: embedding_dim
    7168 with 256 routed experts. Every other configuration gets 0.
    Otherwise each MoE layer reserves ``256 * (allocate_size + n_routed_experts)``
    units, presumably elements, matching ``allocate(M * N)`` in the router
    GEMM; confirm.
    """
    targets_dsv3_shape = embedding_dim == 7168 and n_routed_experts == 256
    if not targets_dsv3_shape:
        return 0
    size_per_layer = 256 * (allocate_size + n_routed_experts)
    return num_moe_layers * size_per_layer
|
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|
1433
1433
|
|
1434
1434
|
return position_ids, mrope_position_deltas
|
1435
1435
|
|
1436
|
-
@staticmethod
|
1437
|
-
def get_next_input_positions(
|
1438
|
-
mrope_position_delta: int,
|
1439
|
-
context_len: int,
|
1440
|
-
seq_len: int,
|
1441
|
-
) -> torch.Tensor:
|
1442
|
-
return torch.tensor(
|
1443
|
-
[
|
1444
|
-
list(
|
1445
|
-
range(
|
1446
|
-
context_len + mrope_position_delta,
|
1447
|
-
seq_len + mrope_position_delta,
|
1448
|
-
)
|
1449
|
-
)
|
1450
|
-
for _ in range(3)
|
1451
|
-
]
|
1452
|
-
)
|
1453
|
-
|
1454
1436
|
|
1455
1437
|
class DualChunkRotaryEmbedding(CustomOp):
|
1456
1438
|
"""Rotary positional embedding for Dual Chunk Attention."""
|