sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (66)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/disaggregation/mini_lb.py +2 -2
  4. sglang/srt/distributed/parallel_state.py +43 -40
  5. sglang/srt/entrypoints/http_server.py +5 -1
  6. sglang/srt/entrypoints/openai/protocol.py +3 -3
  7. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  8. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  9. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  10. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  11. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  12. sglang/srt/layers/attention/aiter_backend.py +93 -68
  13. sglang/srt/layers/communicator.py +45 -7
  14. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/moe/utils.py +0 -1
  16. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  17. sglang/srt/layers/quantization/mxfp4.py +4 -1
  18. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  19. sglang/srt/layers/quantization/quark/utils.py +97 -0
  20. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  21. sglang/srt/layers/rocm_linear_utils.py +44 -0
  22. sglang/srt/layers/rotary_embedding.py +0 -18
  23. sglang/srt/managers/cache_controller.py +42 -39
  24. sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
  25. sglang/srt/managers/schedule_policy.py +3 -2
  26. sglang/srt/managers/scheduler.py +4 -100
  27. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  28. sglang/srt/managers/template_manager.py +3 -3
  29. sglang/srt/managers/tokenizer_manager.py +1 -0
  30. sglang/srt/mem_cache/allocator.py +1 -1
  31. sglang/srt/mem_cache/hicache_storage.py +15 -10
  32. sglang/srt/mem_cache/hiradix_cache.py +5 -5
  33. sglang/srt/mem_cache/memory_pool_host.py +16 -11
  34. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
  35. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  36. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  37. sglang/srt/metrics/collector.py +12 -4
  38. sglang/srt/metrics/utils.py +48 -0
  39. sglang/srt/model_executor/forward_batch_info.py +16 -17
  40. sglang/srt/model_executor/model_runner.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +240 -36
  42. sglang/srt/models/glm4_moe.py +10 -1
  43. sglang/srt/models/internvl.py +28 -0
  44. sglang/srt/models/minicpmv.py +165 -3
  45. sglang/srt/models/qwen2_moe.py +4 -1
  46. sglang/srt/models/qwen3.py +8 -2
  47. sglang/srt/models/qwen3_moe.py +39 -8
  48. sglang/srt/models/torch_native_llama.py +1 -1
  49. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  50. sglang/srt/server_args.py +79 -2
  51. sglang/srt/speculative/eagle_worker.py +158 -112
  52. sglang/srt/utils.py +12 -0
  53. sglang/test/few_shot_gsm8k.py +1 -0
  54. sglang/utils.py +1 -0
  55. sglang/version.py +1 -1
  56. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
  57. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
  58. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  59. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  60. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  61. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  62. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  63. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  64. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  65. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  66. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
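The file above adds pre-tuned Triton launch parameters for the fused-MoE kernel (E=257 experts, N=128, fp8_w8a8 with 128x128 block quantization on an H100). Keys are token counts M; values are the tile and pipeline settings to use at that batch size. A minimal sketch of how such a table is typically consulted, assuming a nearest-key lookup (the helper below is illustrative, not sglang's actual loader):

import json

def pick_moe_config(path: str, num_tokens: int) -> dict:
    # Load the tuned table and select the entry whose batch-size key
    # is numerically closest to the runtime token count.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]

# e.g. num_tokens=100 would select the "96" entry:
# {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#  "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}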
sglang/srt/layers/moe/utils.py
@@ -162,7 +162,6 @@ def get_deepep_config() -> str:
 def is_tbo_enabled() -> bool:
     global IS_TBO_ENABLED
     if IS_TBO_ENABLED is None:
-        logger.warning("IS_TBO_ENABLED is not initialized, using False")
         IS_TBO_ENABLED = False
     return IS_TBO_ENABLED
 
sglang/srt/layers/quantization/modelopt_quant.py
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
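common_group_size() centralizes what was previously a bare config.get("group_size"): it collects group_size from the top level, the "quantization" block, and each (possibly nested) entry in "config_groups", then requires exactly one distinct value. A small sketch of the behavior, using a hypothetical checkpoint config:

cfg = {
    "quantization": {"group_size": 16},
    "config_groups": {
        "group_0": {
            "weights": {"group_size": 16},
            "input_activations": {"group_size": 16},
        }
    },
}
assert ModelOptFp4Config.common_group_size(cfg) == 16  # one consistent value

cfg["config_groups"]["group_0"]["weights"]["group_size"] = 32
# now raises ValueError: Inconsistent group_size values: [16, 32]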
sglang/srt/layers/quantization/mxfp4.py
@@ -816,7 +816,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
-
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
         return fused_moe(
             x,
             layer.w13_weight,
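The new HIP branch upcasts the routing weights because aiter's moe_sorting kernel consumes FP32 weights, while the top-k output here may arrive in a lower precision. The contract in miniature (values illustrative):

import torch

topk_weights = torch.tensor([[0.7, 0.3]], dtype=torch.bfloat16)  # router output
topk_weights = topk_weights.to(torch.float32)  # what the diff does when _is_hip
assert topk_weights.dtype == torch.float32     # dtype aiter's moe_sorting expects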
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         return
 
-        # for aiter implement
-        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
-        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
-
-        # layer.weight = torch.nn.Parameter(wshuffle,
-        #                                   requires_grad=False)
-        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        #                                         requires_grad=False)
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        out_dtype = x.dtype
-        # M = x.shape[0]
-        # N = layer.weight.shape[0]
-
-        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
-        # x, x_scales_shuffle = quant_func(x, shuffle=True)
-
-        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
-
-        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
-
-        # return out[:M]
-
-        # triton implement
-        x_q, x_s = dynamic_mxfp4_quant(x)
-        y = torch.empty(
-            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        # This path does not have support for bias currently
+        assert bias is None, "bias is not supported"
+
+        three_d = False
+        x_s = None
+        y = None
+        if isinstance(x, tuple):
+            assert len(x) in [
+                2,
+                3,
+            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
+            if len(x) == 2:
+                x, x_s = x
+            elif len(x) == 3:
+                x, x_s, y = x
+
+        use_fused_quant_gemm = (
+            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
         )
 
-        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
-
-        return out
+        if x.dim() == 3:
+            three_d = True
+            x = x.view(-1, x.shape[-1])
+        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+
+        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
+        # x_s is not None = true, x_q is uint8 num
+        if use_fused_quant_gemm or x_s is not None:
+            x_q = x
+        else:
+            x_q, x_s = dynamic_mxfp4_quant(x)
+
+        if y is None:
+            y = torch.empty(
+                x_q.shape[0],
+                layer.weight.shape[0],
+                device=x_q.device,
+                dtype=self.out_dtype,
+            )
+
+        if use_fused_quant_gemm:
+            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
+            y = y.to(x.dtype)
+        else:
+            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+
+        if three_d:
+            return y.view(*output_shape)
+
+        return y
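The rewritten apply path now accepts pre-quantized activations and preallocated outputs in addition to plain tensors. A sketch of the three call shapes from the caller's side (variable names follow the diff; that the scheme method is named apply_weights is an assumption):

# 1) plain activations: quantized internally via dynamic_mxfp4_quant
out = scheme.apply_weights(layer, x)

# 2) activations already quantized to mxfp4, with their scales
out = scheme.apply_weights(layer, (x_q, x_s))

# 3) unquantized activations plus a preallocated output buffer y;
#    x_s is None and y.shape[1] == layer.weight.shape[0], so the fused
#    quant+GEMM kernel (gemm_afp4wfp4_pre_quant) writes straight into y
out = scheme.apply_weights(layer, (x, None, y))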
sglang/srt/layers/quantization/quark/utils.py
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional
 
+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from torch import nn
+
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
         elif target == value:
             return True
     return False
+
+
+# utility for tensor dims > 2 cases
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # 2 because we pack fp4 in uint8.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
+    # Since this custom format has no mantissa, treat 2^128 as NaN.
+    x_f32[x_f32 == 128] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+            # The quantized tensor size is only half of original tensor size
+            # and the scaling factor is 1/32, the transpose behavior will be not correct
+            # need to upcast it to fp32 to separate w to w_kc and w_vc
+            # to ensure the following transpose behavior is correct
+            # and then do mxfp4 quant again
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
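mxfp4_to_f32() and e8m0_to_f32() implement the two halves of MXFP4 dequantization: a 16-entry lookup table for the packed 4-bit codes, and a power-of-two decode of the 8-bit shared exponents (bias 127). A worked example of the arithmetic, assuming a CUDA device since the lookup table is created there:

import torch

codes = torch.tensor([[0x2C]], dtype=torch.uint8, device="cuda")
vals = mxfp4_to_f32(codes, False)   # low nibble 0xC -> -2.0, high nibble 0x2 -> 1.0
scale = e8m0_to_f32(torch.tensor([130], dtype=torch.uint8, device="cuda"))
dequant = vals * scale              # scale = 2**(130-127) = 8.0 -> [[-16.0, 8.0]]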
sglang/srt/layers/quantization/rocm_mxfp4_utils.py
@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]
sglang/srt/layers/rocm_linear_utils.py
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
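aiter_dsv3_router_gemm() routes small batches (M <= 256) to an atomic-accumulation GEMM that needs a zero-initialized FP32 output, optionally served from a reusable bump allocator sized by get_dsv3_gemm_output_zero_allocator_size(); larger batches fall through to the plain a16w16 kernel. A usage sketch with DeepSeek-V3-like shapes (illustrative values; requires ROCm with the aiter Triton kernels installed):

import torch

hidden_states = torch.randn(64, 7168, dtype=torch.bfloat16, device="cuda")   # M=64 tokens
router_weight = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")  # 256 routed experts

# M <= 256 and no allocator passed, so a fresh fp32 zero buffer is created and
# gemm_a16w16_atomic accumulates into it; with M > 256 the plain gemm_a16w16
# kernel would be used instead.
logits = aiter_dsv3_router_gemm(hidden_states, router_weight)  # (64, 256), bf16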
sglang/srt/layers/rotary_embedding.py
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
sglang/srt/managers/cache_controller.py
@@ -324,6 +324,22 @@ class HiCacheController:
             group_ranks, backend="gloo"
         )
 
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
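This hunk hoists the storage-backend dispatch out of the transfer/backup hot paths: instead of re-testing storage_backend_type on every batch inside _page_transfer and _page_backup (removed below), the controller binds page_get_func/page_set_func/batch_exists_func once at construction. The pattern in miniature (a generic sketch, not the controller's full code):

class Controller:
    def __init__(self, backend_type: str):
        # Resolve backend-specific callables once, up front.
        self.page_get_func = {
            "mooncake": self._mooncake_page_get,
        }.get(backend_type, self._generic_page_get)

    def _generic_page_get(self, batch):
        ...

    def _mooncake_page_get(self, batch):
        ...

    def transfer(self, batches):
        for batch in batches:  # hot loop stays branch-free
            self.page_get_func(batch)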
@@ -407,6 +423,7 @@ class HiCacheController:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )
@@ -616,13 +633,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)
 
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +659,7 @@ class HiCacheController:
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
@@ -647,9 +670,9 @@ class HiCacheController:
         operation.increment(get_result * self.page_size)
 
     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -659,26 +682,16 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            if operation.increment(self.page_size):
-                self.mem_pool_host.set_from_flat_data_page(
-                    host_indices[i * self.page_size],
-                    page_data[i],
-                )
-            else:
-                break
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
 
     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -687,7 +700,7 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
@@ -744,7 +757,7 @@ class HiCacheController:
                 batch_tokens[i : i + self.page_size], last_hash
             )
             batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -830,30 +843,20 @@ class HiCacheController:
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success
 
     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
            hash_values, host_indices
        )
         return self.storage_backend.batch_set(hashes, dsts)
 
     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +865,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."