sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/mxfp4.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,14 +31,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
@@ -60,17 +61,24 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_hip = is_hip()

 if _is_hip:
     # import aiter
-    from aiter import ActivationType, QuantType, dtypes
-    from aiter.fused_moe import fused_moe
-    from aiter.ops.triton.quant import dynamic_mxfp4_quant
-    from aiter.utility.fp4_utils import e8m0_shuffle
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err


 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -146,27 +154,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)


-try:
-    direct_register_custom_op(
-        op_name="dequant_mxfp4",
-        op_func=_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_dequant_mxfp4_fake,
-    )
-    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
-except AttributeError as error:
-    raise error
-
-try:
-    direct_register_custom_op(
-        op_name="quant_dequant_mxfp4",
-        op_func=_quant_dequant_mxfp4,
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4


 class Mxfp4Config(QuantizationConfig):
@@ -285,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -298,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        intermediate_size_per_partition_after_pad = intermediate_size
+        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
         if _is_sm100_supported:
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 256
+                    intermediate_size_per_partition, 256
                 )
                 hidden_size = round_up(hidden_size, 256)
             else:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 64
+                    intermediate_size_per_partition, 64
                 )
         elif has_triton_kernels:
             # TODO: this is a hack to make
             # intermediate_size_per_partition_after_pad the same as the
             # per_rank_intermediate_size during weight loading
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size, mxfp4_block
+                intermediate_size_per_partition, mxfp4_block
             )

-        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad

         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
@@ -412,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             assert (
                 layer.w13_weight.dim() == 3
                 and layer.w13_weight.shape[0] == self.num_experts
-                and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight.shape[2] == self.hidden_size // 2
             )
             assert (
                 layer.w13_weight_scale.dim() == 3
                 and layer.w13_weight_scale.shape[0] == self.num_experts
-                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_scale.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
             )
             assert (
                 layer.w2_weight.dim() == 3
                 and layer.w2_weight.shape[0] == self.num_experts
                 and layer.w2_weight.shape[1] == self.hidden_size
-                and layer.w2_weight.shape[2] == self.intermediate_size // 2
+                and layer.w2_weight.shape[2]
+                == self.intermediate_size_per_partition // 2
             )
             assert (
                 layer.w2_weight_scale.dim() == 3
                 and layer.w2_weight_scale.shape[1] == self.hidden_size
                 and layer.w2_weight_scale.shape[2]
-                == self.intermediate_size // sf_block_size
+                == self.intermediate_size_per_partition // sf_block_size
             )
             assert (
                 layer.w13_weight_bias.dim() == 2
                 and layer.w13_weight_bias.shape[0] == self.num_experts
-                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_bias.shape[1]
+                == self.intermediate_size_per_partition * 2
             )
             assert (
                 layer.w2_weight_bias.dim() == 2
@@ -513,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 torch.stack(gemm1_scales_mxfp4_shuffled)
                 .reshape(
                     self.num_experts,
-                    2 * self.intermediate_size,
+                    2 * self.intermediate_size_per_partition,
                     self.hidden_size // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)
@@ -525,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .reshape(
                     self.num_experts,
                     self.hidden_size,
-                    self.intermediate_size // sf_block_size,
+                    self.intermediate_size_per_partition // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)
             )
@@ -615,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

         return tile_tokens_dim

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:

+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
         from sglang.srt.layers.moe.topk import TopKOutputChecker

+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
         if self.use_flashinfer:
             # When bf16 mode is enabled, we don't need to quantize the input,
             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -676,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 top_k,
                 None,  # n_group # TODO: support n_group
                 None,  # topk_group # TODO: support topk_group
-                self.intermediate_size,  # padded to multiple of 256
+                self.intermediate_size_per_partition,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
@@ -684,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
-            return trtllm_gen_output
+            return StandardCombineInput(hidden_states=trtllm_gen_output)

         if self.use_triton_kernels:
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-                return self.triton_kernel_moe_with_bias_forward(
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
                     w1_pcg=self.w13_precision_config,
@@ -703,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     moe_runner_config=moe_runner_config,
                 )
             else:
-                return self.triton_kernel_moe_forward(
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
-            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-            return fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
-                b1=layer.w13_weight_bias,
-                b2=layer.w2_weight_bias,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_weight_bias=layer.w13_weight_bias,
+                w2_weight_bias=layer.w2_weight_bias,
             )
+            return self.runner.run(dispatch_output, quant_info)


 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -800,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):

         return w, mx_scales

-    def process_weights_after_loading(self, layer: Module) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)

@@ -810,16 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        topk_weights, topk_ids, _ = topk_output
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

-        return fused_moe(
+        topk_weights, topk_ids, _ = topk_output
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -830,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             activation=(
                 ActivationType.Silu
-                if moe_runner_config.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
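
The hunks above capture the 0.5.2 change to the fused-MoE quantization-method interface: apply() no longer takes x, topk_output, and moe_runner_config as separate arguments; the config is cached by the new create_moe_runner() hook, activations arrive wrapped in a StandardDispatchOutput, and a combine input is returned. Below is a minimal sketch of that calling convention using stand-in dataclasses rather than the real sglang classes; field names beyond hidden_states and topk_output are assumptions taken from the attribute accesses visible above.

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class StandardDispatchOutput:
    # Stand-in for sglang's StandardDispatchOutput; only the two attributes
    # read in the hunks above are modeled here.
    hidden_states: torch.Tensor
    topk_output: Any


@dataclass
class StandardCombineInput:
    # Stand-in for StandardCombineInput; the diff only ever passes hidden_states.
    hidden_states: torch.Tensor


class ToyMoEMethod:
    # Mimics the two-phase contract seen above: configure once, apply per batch.
    def create_moe_runner(self, layer: Any, moe_runner_config: dict) -> None:
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer: Any, dispatch_output: StandardDispatchOutput
    ) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        # ... the expert GEMMs would run here ...
        return StandardCombineInput(hidden_states=x)


method = ToyMoEMethod()
method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
out = method.apply(
    layer=None,
    dispatch_output=StandardDispatchOutput(
        hidden_states=torch.randn(4, 8), topk_output=None
    ),
)
print(out.hidden_states.shape)  # torch.Size([4, 8])

The same configure-once, run-per-batch pattern recurs in the Quark hunks further down.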

sglang/srt/layers/quantization/mxfp4_tensor.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional
+
 import torch


@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
     E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+    def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
         """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
         Args:
             input (torch.Tensor): The input tensor to be quantized.
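
The only substantive change to MXFP4QuantizeUtil is the annotation: the PEP 604 spelling int | None is evaluated when the method is defined and raises TypeError on Python versions before 3.10, while Optional[int] works everywhere. A small illustration with a hypothetical function (not from the package):

from typing import Optional


def pick_block_size(block_size: Optional[int]) -> int:
    # Optional[int] is equivalent to Union[int, None]; unlike `int | None`
    # it does not need Python 3.10+ (or `from __future__ import annotations`)
    # to evaluate at function-definition time.
    return 32 if block_size is None else block_size


print(pick_block_size(None))  # 32
print(pick_block_size(64))    # 64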

sglang/srt/layers/quantization/quark/quark_moe.py
@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle

+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)

 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 OCP_MX_BLOCK_SIZE = 32

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-
-
-class QuarkMoEMethod:
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+    from sglang.srt.layers.quantization import QuarkConfig
+
+
+class QuarkMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.quant_config = quant_config

     @staticmethod
     def get_moe_method(
-        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
         module: torch.nn.Module,
         layer_name: str,
     ) -> "QuarkMoEMethod":
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output

-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
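
These Quark hunks also drop the __new__ trick that spliced FusedMoEMethodBase into QuarkMoEMethod's bases through a dynamically built type() call, replacing it with ordinary inheritance plus the same create_moe_runner/apply pair. A hedged sketch of the resulting shape, using placeholder classes rather than the real sglang ones:

from typing import Any


class FusedMoEMethodBase:
    # Placeholder for sglang's base class; only the two hooks used above appear.
    def create_moe_runner(self, layer: Any, moe_runner_config: Any) -> None:
        raise NotImplementedError

    def apply(self, layer: Any, dispatch_output: Any) -> Any:
        raise NotImplementedError


class QuarkLikeMoEMethod(FusedMoEMethodBase):
    # Plain subclassing replaces assembling the base class at instantiation
    # time, so isinstance() checks and the __init__ signature stay ordinary.
    def __init__(self, quant_config: Any) -> None:
        self.quant_config = quant_config

    def create_moe_runner(self, layer: Any, moe_runner_config: Any) -> None:
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: Any, dispatch_output: Any) -> Any:
        return dispatch_output  # the real method calls aiter's fused_moe here


method = QuarkLikeMoEMethod(quant_config={})
print(isinstance(method, FusedMoEMethodBase))  # True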

sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         return

-        # for aiter implement
-        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
-        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
-
-        # layer.weight = torch.nn.Parameter(wshuffle,
-        # requires_grad=False)
-        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        # requires_grad=False)
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        out_dtype = x.dtype
-        # M = x.shape[0]
-        # N = layer.weight.shape[0]
-
-        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
-        # x, x_scales_shuffle = quant_func(x, shuffle=True)
-
-        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
-
-        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
-
-        # return out[:M]
-
-        # triton implement
-        x_q, x_s = dynamic_mxfp4_quant(x)
-        y = torch.empty(
-            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        # This path does not have support for bias currently
+        assert bias is None, "bias is not supported"
+
+        three_d = False
+        x_s = None
+        y = None
+        if isinstance(x, tuple):
+            assert len(x) in [
+                2,
+                3,
+            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
+            if len(x) == 2:
+                x, x_s = x
+            elif len(x) == 3:
+                x, x_s, y = x
+
+        use_fused_quant_gemm = (
+            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
         )

-        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
-
-        return out
+        if x.dim() == 3:
+            three_d = True
+            x = x.view(-1, x.shape[-1])
+            output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+
+        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
+        # x_s is not None = true, x_q is uint8 num
+        if use_fused_quant_gemm or x_s is not None:
+            x_q = x
+        else:
+            x_q, x_s = dynamic_mxfp4_quant(x)
+
+        if y is None:
+            y = torch.empty(
+                x_q.shape[0],
+                layer.weight.shape[0],
+                device=x_q.device,
+                dtype=self.out_dtype,
+            )
+
+        if use_fused_quant_gemm:
+            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
+            y = y.to(x.dtype)
+        else:
+            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+
+        if three_d:
+            return y.view(*output_shape)
+
+        return y
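
The rewritten forward path in QuarkW4A4MXFP4 now accepts its activation either as a bare tensor, as (x, x_s) with a precomputed MXFP4 scale, or as (x, x_s, y) with a preallocated output buffer, and it routes to the fused quantize+GEMM kernel only when a correctly shaped output buffer arrives without a scale. A small illustration of that dispatch rule with a hypothetical helper (not part of the package):

import torch


def unpack_gemm_input(x):
    # Hypothetical helper mirroring the tuple convention in the hunk above:
    # a bare tensor, (x, x_s) with a precomputed scale, or (x, x_s, y) with a
    # preallocated output buffer.
    x_s = None
    y = None
    if isinstance(x, tuple):
        assert len(x) in (2, 3), "expected (x, x_s) or (x, x_s, y)"
        if len(x) == 2:
            x, x_s = x
        else:
            x, x_s, y = x
    return x, x_s, y


weight_rows = 128  # stands in for layer.weight.shape[0]
x = torch.randn(4, 64)
y = torch.empty(4, weight_rows)

act, scale, out = unpack_gemm_input((x, None, y))
# Fused quantize+GEMM is chosen only when no scale was precomputed but a
# correctly shaped output buffer was handed in.
use_fused_quant_gemm = scale is None and out is not None and out.shape[1] == weight_rows
print(use_fused_quant_gemm)  # True

Callers that already hold an MXFP4-quantized activation (the uint8 case commented in the diff) pass the scale explicitly and skip the in-kernel quantization.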

sglang/srt/layers/quantization/quark/utils.py
@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional

+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from torch import nn
+

 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
+
+
+# utility for tensor dims > 2 cases
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # 2 because we pack fp4 in uint8.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
+    # Since this custom format has no mantissa, treat 2^128 as NaN.
+    x_f32[x_f32 == 128] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+            # The quantized tensor size is only half of original tensor size
+            # and the scaling factor is 1/32, the transpose behavior will be not correct
+            # need to upcast it to fp32 to separate w to w_kc and w_vc
+            # to ensure the following transpose behavior is correct
+            # and then do mxfp4 quant again
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
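
The helpers added here lean on two small decode rules: an e8m0 scale is a bare biased exponent (value = 2 ** (e - 127)), and MXFP4 packs two E2M1 codes per uint8, low nibble first, decoded through a 16-entry value table. The arithmetic can be checked on CPU; the helper above builds its lookup table on CUDA, so this sketch only reproduces the math.

import torch

# e8m0 decode: an 8-bit biased exponent with no mantissa bits.
scales = torch.tensor([126, 127, 128, 130], dtype=torch.uint8)
print(2.0 ** (scales.to(torch.float32) - 127))  # -> 0.5, 1.0, 2.0, 8.0

# MXFP4 (E2M1) decode: two 4-bit codes per uint8, low nibble first, indexing
# the same 16-entry value table used in mxfp4_to_f32 above.
e2m1_values = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)
packed = torch.tensor([0x73], dtype=torch.uint8)  # low nibble 0x3, high nibble 0x7
low, high = packed & 0xF, packed >> 4
print(e2m1_values[low.long()], e2m1_values[high.long()])  # -> 1.5 and 6.0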

sglang/srt/layers/quantization/rocm_mxfp4_utils.py (new file)
@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]