sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
@@ -175,6 +175,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
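
The two added fields give each expert-parallel (EP) rank a contiguous slice of the global expert range, which the W4AFP8 MoE path further down relies on (layer.start_expert_id). A minimal sketch of the arithmetic; the expert count and EP size are illustrative, not taken from the diff:

    # Illustrative only: how the new fields partition experts across EP ranks.
    num_experts = 256   # hypothetical global expert count
    moe_ep_size = 8     # hypothetical expert-parallel world size
    num_local_experts = num_experts // moe_ep_size  # 32 experts per rank

    for moe_ep_rank in range(moe_ep_size):
        start_expert_id = moe_ep_rank * num_local_experts        # 0, 32, 64, ...
        end_expert_id = start_expert_id + num_local_experts - 1  # 31, 63, 95, ...
        print(moe_ep_rank, start_expert_id, end_expert_id)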
@@ -593,8 +595,9 @@ class FusedMoE(torch.nn.Module):

        if (
            "compressed" in self.quant_method.__class__.__name__.lower()
-            and param.data[expert_id] != 1
-            and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            or "w4afp8" in self.quant_config.get_name()
+            and (param.data[expert_id] != 1).any()
+            and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
        ):
            raise ValueError(
                "input_scales of w1 and w3 of a layer "
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
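
A usage sketch for the relocated helper, following the docstring's own example. It assumes a CUDA/ROCm device and the sgl_kernel extension are available; the int32 dtype for topk_ids is an assumption on my part:

    import torch
    from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
        moe_align_block_size,
    )

    topk_ids = torch.tensor(
        [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], dtype=torch.int32, device="cuda"
    )
    sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
        topk_ids, block_size=4, num_experts=4
    )
    # sorted_ids groups the 12 (token, expert) pairs by expert and pads each
    # expert's slice to a multiple of block_size; padded slots hold
    # topk_ids.numel() == 12 and are skipped by the subsequent MoE GEMM.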
@@ -162,7 +162,6 @@ def get_deepep_config() -> str:
 def is_tbo_enabled() -> bool:
     global IS_TBO_ENABLED
     if IS_TBO_ENABLED is None:
-        logger.warning("IS_TBO_ENABLED is not initialized, using False")
         IS_TBO_ENABLED = False
     return IS_TBO_ENABLED

@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
         kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
     )

+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
     # TODO can use multi thread
     for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
         executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()


 class _BaseWarmupExecutor:
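
The warmup now pins deep_gemm's compile mode and restores the previous value afterwards. A hedged sketch of the same save/restore pattern wrapped in try/finally, using only the get_compile_mode/set_compile_mode calls that appear in the diff; the helper name is hypothetical:

    def _warmup_with_compile_mode(deep_gemm, executor, m_list):
        # Save the mode, force compile mode 1 for the warmup loop, always restore.
        old_compile_mode = deep_gemm.get_compile_mode()
        deep_gemm.set_compile_mode(1)
        try:
            for m in m_list:
                executor.execute(m=m)
        finally:
            deep_gemm.set_compile_mode(old_compile_mode)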
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]

+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
            else:
                kv_cache_quant_algo = "auto"

-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
            exclude_modules = config.get("ignore", [])
        else:
            # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                if not kv_cache_quant_algo:
                    kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                exclude_modules = quant_config.get("exclude_modules", [])
            except (ValueError, KeyError):
                raise ValueError(
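
common_group_size is now the single source of truth for group_size in both config formats: it collects every group_size it can find and insists they agree. A small illustration with made-up config fragments:

    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

    flat_cfg = {
        "group_size": 16,
        "config_groups": {
            "group_0": {
                "weights": {"group_size": 16},
                "input_activations": {"group_size": 16},
            }
        },
    }
    assert ModelOptFp4Config.common_group_size(flat_cfg) == 16

    mismatched = {"group_size": 16, "quantization": {"group_size": 32}}
    # ModelOptFp4Config.common_group_size(mismatched) raises
    # ValueError("Inconsistent group_size values: [16, 32]")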
@@ -816,7 +816,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
-
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
         return fused_moe(
             x,
             layer.w13_weight,
@@ -8,6 +8,7 @@ import torch.nn.functional as F
     from aiter.ops.gemm_op_a4w4 import gemm_a4w4
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
     from aiter.utility import dtypes
     from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         return

-        # for aiter implement
-        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
-        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
-
-        # layer.weight = torch.nn.Parameter(wshuffle,
-        #                                   requires_grad=False)
-        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        #                                         requires_grad=False)
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        out_dtype = x.dtype
-        # M = x.shape[0]
-        # N = layer.weight.shape[0]
-
-        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
-        # x, x_scales_shuffle = quant_func(x, shuffle=True)
-
-        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
-
-        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
-
-        # return out[:M]
-
-        # triton implement
-        x_q, x_s = dynamic_mxfp4_quant(x)
-        y = torch.empty(
-            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        # This path does not have support for bias currently
+        assert bias is None, "bias is not supported"
+
+        three_d = False
+        x_s = None
+        y = None
+        if isinstance(x, tuple):
+            assert len(x) in [
+                2,
+                3,
+            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
+            if len(x) == 2:
+                x, x_s = x
+            elif len(x) == 3:
+                x, x_s, y = x
+
+        use_fused_quant_gemm = (
+            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
         )

-        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
-
-        return out
+        if x.dim() == 3:
+            three_d = True
+            x = x.view(-1, x.shape[-1])
+        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+
+        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
+        # x_s is not None = true, x_q is uint8 num
+        if use_fused_quant_gemm or x_s is not None:
+            x_q = x
+        else:
+            x_q, x_s = dynamic_mxfp4_quant(x)
+
+        if y is None:
+            y = torch.empty(
+                x_q.shape[0],
+                layer.weight.shape[0],
+                device=x_q.device,
+                dtype=self.out_dtype,
+            )
+
+        if use_fused_quant_gemm:
+            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
+            y = y.to(x.dtype)
+        else:
+            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+
+        if three_d:
+            return y.view(*output_shape)
+
+        return y
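
The rewritten apply path changes the activation contract: the input may be a plain tensor or a tuple that already carries MXFP4 scales and, optionally, a preallocated output buffer. A small illustrative helper (not part of sglang) that mirrors the tuple handling above:

    def _unpack_mxfp4_input(x):
        # Mirrors the accepted forms in the diff: x, (x, x_s), or (x, x_s, y).
        x_s = None
        y = None
        if isinstance(x, tuple):
            if len(x) == 2:
                x, x_s = x
            elif len(x) == 3:
                x, x_s, y = x
            else:
                raise AssertionError("expected (x, x_s) or (x, x_s, y)")
        return x, x_s, y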
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional

+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from torch import nn
+

 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
+
+
+# utility for tensor dims > 2 cases
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # 2 because we pack fp4 in uint8.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
+    # Since this custom format has no mantissa, treat 2^128 as NaN.
+    x_f32[x_f32 == 128] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+            # The quantized tensor size is only half of original tensor size
+            # and the scaling factor is 1/32, the transpose behavior will be not correct
+            # need to upcast it to fp32 to separate w to w_kc and w_vc
+            # to ensure the following transpose behavior is correct
+            # and then do mxfp4 quant again
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
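
e8m0_to_f32 decodes a biased power-of-two exponent, 2 ** (byte - 127). A short worked example with illustrative byte values:

    import torch

    codes = torch.tensor([126, 127, 128, 130], dtype=torch.uint8)
    print(2.0 ** (codes.to(torch.float32) - 127))
    # tensor([0.5000, 1.0000, 2.0000, 8.0000])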
@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]
@@ -1,12 +1,14 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -91,12 +93,13 @@ class W4AFp8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.managers.schedule_batch import global_server_args_dict

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, EPMoE):
+        elif isinstance(layer, FusedMoE):
             return W4AFp8MoEMethod(self)
         return None

@@ -104,8 +107,24 @@ class W4AFp8Config(QuantizationConfig):
         return []


-class W4AFp8MoEMethod(FusedMoEMethodBase):
+def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+    s_shape = scales.shape
+    # Reshape to separate groups of 4
+    alignment = 4 if s_shape[2] % 4 == 0 else 1
+    scales_interleaved = scales.reshape(
+        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
+    )
+    # Permute dimensions to interleave
+    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+    # Reshape back to original dimensions but with interleaved values
+    scales_interleaved = scales_interleaved.reshape(
+        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
+    )
+    return scales_interleaved.contiguous()
+
+
+class W4AFp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config
@@ -234,33 +253,18 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         return

-    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
-        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
-        s_shape = scales.shape
-        # Reshape to separate groups of 4
-        scales_interleaved = scales.reshape(
-            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
-        )
-        # Permute dimensions to interleave
-        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
-        # Reshape back to original dimensions but with interleaved values
-        scales_interleaved = scales_interleaved.reshape(
-            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
-        )
-        return scales_interleaved.contiguous()
-
     def process_weights_after_loading(self, layer: Module) -> None:
         dtype = torch.bfloat16
         device = layer.w2_weight.device

         # Interleave w13_weight_scale (gate_up_proj)
         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
-        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        w13_weight_scale = interleave_scales(w13_weight_scale)
         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

         # Interleave w2_weight_scale (down_proj)
         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
-        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        w2_weight_scale = interleave_scales(w2_weight_scale)
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

         # Process input scales
@@ -291,11 +295,12 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        local_topk_ids = torch.where(
-            topk_ids == -1,
-            layer.num_experts,
-            topk_ids,
-        )
+        if get_moe_expert_parallel_world_size() > 1:
+            local_topk_ids = torch.where(
+                topk_ids == -1,
+                layer.num_experts,
+                topk_ids,
+            )

         output = cutlass_w4a8_moe(
             layer.start_expert_id,
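
With expert parallelism, top-k ids that route to experts hosted on other ranks arrive as -1; the new guard remaps them to layer.num_experts (a sentinel slot) only when EP is actually in use. An illustration with made-up routing ids:

    import torch

    num_experts = 8  # illustrative per-layer expert count
    topk_ids = torch.tensor([[0, 3, -1], [5, -1, 7]])
    local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
    print(local_topk_ids)  # tensor([[0, 3, 8], [5, 8, 7]])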
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
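
get_dsv3_gemm_output_zero_allocator_size only budgets for the DeepSeek-V3-shaped case (embedding_dim 7168, 256 routed experts) and otherwise returns 0. The arithmetic for an illustrative call; the layer count and allocate_size below are made up:

    n_routed_experts = 256
    allocate_size = 128
    num_moe_layers = 58

    per_layer_size = 256 * (allocate_size + n_routed_experts)  # 256 * 384 = 98304
    total = num_moe_layers * per_layer_size                    # 58 * 98304 = 5701632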
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):

         return position_ids, mrope_position_deltas

-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-

 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""