sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py (+93 -68):

@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,14 +31,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
@@ -60,17 +61,24 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_hip = is_hip()
 
 if _is_hip:
     # import aiter
-
-
-
-
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err
 
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -146,27 +154,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
 
 
 class Mxfp4Config(QuantizationConfig):
@@ -285,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -298,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        intermediate_size_per_partition_after_pad =
+        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
         if _is_sm100_supported:
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
-
+                    intermediate_size_per_partition, 256
                 )
                 hidden_size = round_up(hidden_size, 256)
             else:
                 intermediate_size_per_partition_after_pad = round_up(
-
+                    intermediate_size_per_partition, 64
                 )
         elif has_triton_kernels:
             # TODO: this is a hack to make
             # intermediate_size_per_partition_after_pad the same as the
             # per_rank_intermediate_size during weight loading
             intermediate_size_per_partition_after_pad = round_up(
-
+                intermediate_size_per_partition, mxfp4_block
             )
 
-        self.
+        self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
 
         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
@@ -412,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             assert (
                 layer.w13_weight.dim() == 3
                 and layer.w13_weight.shape[0] == self.num_experts
-                and layer.w13_weight.shape[1]
+                and layer.w13_weight.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight.shape[2] == self.hidden_size // 2
             )
             assert (
                 layer.w13_weight_scale.dim() == 3
                 and layer.w13_weight_scale.shape[0] == self.num_experts
-                and layer.w13_weight_scale.shape[1]
+                and layer.w13_weight_scale.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
             )
             assert (
                 layer.w2_weight.dim() == 3
                 and layer.w2_weight.shape[0] == self.num_experts
                 and layer.w2_weight.shape[1] == self.hidden_size
-                and layer.w2_weight.shape[2]
+                and layer.w2_weight.shape[2]
+                == self.intermediate_size_per_partition // 2
             )
             assert (
                 layer.w2_weight_scale.dim() == 3
                 and layer.w2_weight_scale.shape[1] == self.hidden_size
                 and layer.w2_weight_scale.shape[2]
-                == self.
+                == self.intermediate_size_per_partition // sf_block_size
             )
             assert (
                 layer.w13_weight_bias.dim() == 2
                 and layer.w13_weight_bias.shape[0] == self.num_experts
-                and layer.w13_weight_bias.shape[1]
+                and layer.w13_weight_bias.shape[1]
+                == self.intermediate_size_per_partition * 2
             )
             assert (
                 layer.w2_weight_bias.dim() == 2
@@ -513,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             torch.stack(gemm1_scales_mxfp4_shuffled)
             .reshape(
                 self.num_experts,
-                2 * self.
+                2 * self.intermediate_size_per_partition,
                 self.hidden_size // sf_block_size,
             )
             .view(torch.float8_e4m3fn)
@@ -525,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             .reshape(
                 self.num_experts,
                 self.hidden_size,
-                self.
+                self.intermediate_size_per_partition // sf_block_size,
             )
             .view(torch.float8_e4m3fn)
         )
@@ -615,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         return tile_tokens_dim
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
         from sglang.srt.layers.moe.topk import TopKOutputChecker
 
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
         if self.use_flashinfer:
             # When bf16 mode is enabled, we don't need to quantize the input,
             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -676,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 top_k,
                 None,  # n_group # TODO: support n_group
                 None,  # topk_group # TODO: support topk_group
-                self.
+                self.intermediate_size_per_partition,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
@@ -684,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
-            return trtllm_gen_output
+            return StandardCombineInput(hidden_states=trtllm_gen_output)
 
         if self.use_triton_kernels:
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
                     w1_pcg=self.w13_precision_config,
@@ -703,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     moe_runner_config=moe_runner_config,
                 )
             else:
-
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
-
-
-
-
-
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
-                b1=layer.w13_weight_bias,
-                b2=layer.w2_weight_bias,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_weight_bias=layer.w13_weight_bias,
+                w2_weight_bias=layer.w2_weight_bias,
             )
+            return self.runner.run(dispatch_output, quant_info)
 
 
 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -800,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
 
         return w, mx_scales
 
-    def process_weights_after_loading(self, layer: Module) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
 
@@ -810,16 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
-
+        topk_weights, topk_ids, _ = topk_output
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -830,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             activation=(
                 ActivationType.Silu
-                if moe_runner_config.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/mxfp4_tensor.py (+3 -1):

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import torch
 
 
@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
     E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
 
     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int
+    def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
         """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
         Args:
             input (torch.Tensor): The input tensor to be quantized.
sglang/srt/layers/quantization/quark/quark_moe.py (+32 -27):

@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)
 
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 OCP_MX_BLOCK_SIZE = 32
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.
-
-
-class QuarkMoEMethod:
-
-
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+    from sglang.srt.layers.quantization import QuarkConfig
+
+
+class QuarkMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.quant_config = quant_config
 
     @staticmethod
     def get_moe_method(
-        quant_config:
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
         module: torch.nn.Module,
         layer_name: str,
     ) -> "QuarkMoEMethod":
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
-
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (+49 -30):

@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from aiter.ops.gemm_op_a4w4 import gemm_a4w4
 from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
 from aiter.utility import dtypes
 from aiter.utility.fp4_utils import e8m0_shuffle
@@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         return
 
-        # for aiter implement
-        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
-        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
-
-        # layer.weight = torch.nn.Parameter(wshuffle,
-        # requires_grad=False)
-        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        # requires_grad=False)
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # This path does not have support for bias currently
+        assert bias is None, "bias is not supported"
+
+        three_d = False
+        x_s = None
+        y = None
+        if isinstance(x, tuple):
+            assert len(x) in [
+                2,
+                3,
+            ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
+            if len(x) == 2:
+                x, x_s = x
+            elif len(x) == 3:
+                x, x_s, y = x
+
+        use_fused_quant_gemm = (
+            x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
         )
 
-
-
-
+        if x.dim() == 3:
+            three_d = True
+            x = x.view(-1, x.shape[-1])
+        output_shape = [*x.shape[:-1], layer.weight.shape[0]]
+
+        # use_fused_quant_gemm = true, x_q is a bf16/fp16 num
+        # x_s is not None = true, x_q is uint8 num
+        if use_fused_quant_gemm or x_s is not None:
+            x_q = x
+        else:
+            x_q, x_s = dynamic_mxfp4_quant(x)
+
+        if y is None:
+            y = torch.empty(
+                x_q.shape[0],
+                layer.weight.shape[0],
+                device=x_q.device,
+                dtype=self.out_dtype,
+            )
+
+        if use_fused_quant_gemm:
+            gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
+            y = y.to(x.dtype)
+        else:
+            gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
+
+        if three_d:
+            return y.view(*output_shape)
+
+        return y
sglang/srt/layers/quantization/quark/utils.py (+97 -0):

@@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Any, Optional
 
+import torch
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from torch import nn
+
 
 def deep_compare(dict1: Any, dict2: Any) -> bool:
     if type(dict1) is not type(dict2):
@@ -105,3 +109,96 @@ def _is_equal_or_regex_match(
     elif target == value:
         return True
     return False
+
+
+# utility for tensor dims > 2 cases
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # 2 because we pack fp4 in uint8.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
+    # This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
+
+    # Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
+    # Since this custom format has no mantissa, treat 2^128 as NaN.
+    x_f32[x_f32 == 128] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
+        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
+        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # when dtype is uint8, it means the w has been quantized to mxfp4 format
+            # but we must separate it to w_kc and w_vc.
+            # The quantized tensor size is only half of original tensor size
+            # and the scaling factor is 1/32, the transpose behavior will be not correct
+            # need to upcast it to fp32 to separate w to w_kc and w_vc
+            # to ensure the following transpose behavior is correct
+            # and then do mxfp4 quant again
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
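The helpers added to quark/utils.py deal with the two packed formats used by these MXFP4 paths: weight tensors that store two FP4 (E2M1) codes per uint8, low nibble first, and per-block e8m0 scales that store only a biased exponent, one scale for every 32 values. A short worked example of those decoding rules (illustration only, independent of the package code):

```python
# The 16 E2M1 code values, in the same order as mxfp4_list above.
E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def decode_e8m0(byte: int) -> float:
    # An e8m0 byte is just a biased exponent: value = 2 ** (byte - 127).
    return 2.0 ** (byte - 127)  # e.g. 127 -> 1.0, 130 -> 8.0, 120 -> 1/128


def decode_mxfp4_byte(byte: int) -> tuple:
    # Same nibble split as mxfp4_to_f32 above: even positions take the low nibble.
    low, high = byte & 0xF, byte >> 4
    return E2M1_TABLE[low], E2M1_TABLE[high]


assert decode_e8m0(127) == 1.0
assert decode_e8m0(130) == 8.0
assert decode_mxfp4_byte(0x73) == (1.5, 6.0)  # low nibble 0x3 -> 1.5, high 0x7 -> 6.0
# A dequantized element is code_value * block_scale, which is what
# quark_post_load_weights reproduces via mxfp4_to_f32 and e8m0_to_f32.
```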
sglang/srt/layers/quantization/rocm_mxfp4_utils.py (+13 -0, new file):

@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]