sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +43 -40
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +4 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +5 -5
- sglang/srt/mem_cache/memory_pool_host.py +16 -11
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +240 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
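For context on how such a table is consumed: the fused-MoE Triton launcher loads the JSON keyed by token count M and picks the closest tuned entry. A minimal sketch of that lookup, assuming a local copy of the file above (the helper names are illustrative, not sglang's exact API):

import json

def load_moe_config(path: str) -> dict:
    # JSON keys are strings ("1", "2", ..., "4096"); convert to ints for lookup.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, m: int) -> dict:
    # Select the tuned entry whose batch-size key is closest to the actual M.
    return configs[min(configs, key=lambda k: abs(k - m))]

# e.g. M = 700 resolves to the "512" entry
# (BLOCK_SIZE_M=16, GROUP_SIZE_M=1, num_stages=3).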
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
sglang/srt/layers/quantization/mxfp4.py
CHANGED
@@ -816,7 +816,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
-
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
         return fused_moe(
             x,
             layer.w13_weight,
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         return
 
-        # for aiter implement
-        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
-        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
-
-        # layer.weight = torch.nn.Parameter(wshuffle,
-        #                                   requires_grad=False)
-        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        #                                         requires_grad=False)
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # This path does not have support for bias currently
+        assert bias is None, "bias is not supported"
+
+        three_d = False
+        x_s = None
+        y = None
+        if isinstance(x, tuple):
+            assert len(x) in [
+                2,
+                3,
+            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
+            if len(x) == 2:
+                x, x_s = x
+            elif len(x) == 3:
+                x, x_s, y = x
+
+        use_fused_quant_gemm = (
+            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
         )
 
-
-
-
+        if x.dim() == 3:
+            three_d = True
+            x = x.view(-1, x.shape[-1])
+        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+
+        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
+        # x_s is not None = true, x_q is uint8 num
+        if use_fused_quant_gemm or x_s is not None:
+            x_q = x
+        else:
+            x_q, x_s = dynamic_mxfp4_quant(x)
+
+        if y is None:
+            y = torch.empty(
+                x_q.shape[0],
+                layer.weight.shape[0],
+                device=x_q.device,
+                dtype=self.out_dtype,
+            )
+
+        if use_fused_quant_gemm:
+            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
+            y = y.to(x.dtype)
+        else:
+            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+
+        if three_d:
+            return y.view(*output_shape)
+
+        return y
sglang/srt/layers/quantization/quark/utils.py
CHANGED
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional
 
+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from torch import nn
+
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
+
+
+# utility for tensor dims > 2 cases
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # 2 because we pack fp4 in uint8.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
+    # Since this custom format has no mantissa, treat 2^128 as NaN.
+    x_f32[x_f32 == 128] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+            # The quantized tensor size is only half of original tensor size
+            # and the scaling factor is 1/32, the transpose behavior will be not correct
+            # need to upcast it to fp32 to separate w to w_kc and w_vc
+            # to ensure the following transpose behavior is correct
+            # and then do mxfp4 quant again
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
sglang/srt/layers/quantization/rocm_mxfp4_utils.py
ADDED
@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]
sglang/srt/layers/rocm_linear_utils.py
ADDED
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
sglang/srt/managers/cache_controller.py
CHANGED
@@ -324,6 +324,22 @@ class HiCacheController:
             group_ranks, backend="gloo"
         )
 
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -407,6 +423,7 @@ class HiCacheController:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )
@@ -616,13 +633,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)
 
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +659,7 @@ class HiCacheController:
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
@@ -647,9 +670,9 @@ class HiCacheController:
         operation.increment(get_result * self.page_size)
 
     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [
-            hash_values
-
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -659,26 +682,16 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-
-
-
-
-
-
-
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
 
     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -687,7 +700,7 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
@@ -744,7 +757,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
            hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -830,30 +843,20 @@ class HiCacheController:
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success
 
     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)
 
     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +865,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."