sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -23,8 +23,13 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+    CombineInput,
+    StandardDispatcher,
+)
 from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None
 
-        self.moe_runner_config = MoeRunnerConfig(
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-        )
-
         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
 
         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -175,6 +170,8 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
@@ -194,13 +191,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
 
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
-            )
-        else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None
 
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -211,12 +201,40 @@ class FusedMoE(torch.nn.Module):
             and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
+
+        self.moe_runner_config = MoeRunnerConfig(
+            num_experts=num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
+            layer_id=layer_id,
+            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
+            params_dtype=params_dtype,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+
+        if quant_config is None:
+            self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
+        else:
+            self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                self, prefix
+            )
+        assert self.quant_method is not None
+
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
             hidden_size=hidden_size,
-            # FIXME: figure out which intermediate_size to use
-            intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=(
@@ -227,6 +245,9 @@ class FusedMoE(torch.nn.Module):
             with_bias=with_bias,
         )
 
+        self.quant_method.create_moe_runner(self, self.moe_runner_config)
+        self.dispatcher = StandardDispatcher()
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -592,9 +613,12 @@ class FusedMoE(torch.nn.Module):
             loaded_weight = loaded_weight.to(param.data.device)
 
             if (
-
-
-
+                (
+                    "compressed" in self.quant_method.__class__.__name__.lower()
+                    or "w4afp8" in self.quant_config.get_name()
+                )
+                and (param.data[expert_id] != 1).any()
+                and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
             ):
                 raise ValueError(
                     "input_scales of w1 and w3 of a layer "
@@ -808,16 +832,17 @@ class FusedMoE(torch.nn.Module):
         elif TopKOutputChecker.format_is_triton_kernel(topk_output):
            raise NotImplementedError()
 
-
-
+        dispatch_output = self.dispatcher.dispatch(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
 
-
-
-
-
-
-
-
+        # TODO: consider using symmetric memory
+        combine_input = self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
+        final_hidden_states = self.dispatcher.combine(combine_input)
 
         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
@@ -952,7 +977,6 @@ class FlashInferFusedMoE(FusedMoE):
             layer=self,
             x=hidden_states,
             topk_output=topk_output,
-            moe_runner_config=self.moe_runner_config,
         )
 
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1052,16 +1076,3 @@ class FlashInferFP4MoE(FusedMoE):
         )[0]
 
         return result
-
-
-def get_fused_moe_impl_class():
-    """Factory function to get the appropriate FusedMoE implementation class."""
-    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-        # Use FP4 variant when FP4 quantization is enabled
-        return FlashInferFP4MoE
-    elif should_use_flashinfer_trtllm_moe():
-        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-        return FlashInferFusedMoE
-    else:
-        # Default case
-        return FusedMoE
```
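The hunks above replace the old direct `quant_method.apply(..., moe_runner_config=...)` call in `FusedMoE.forward` with a three-step flow: `dispatcher.dispatch` wraps the hidden states and top-k routing, `quant_method.apply` consumes the dispatch output, and `dispatcher.combine` produces the final hidden states. A minimal sketch of that control flow follows; `ToyDispatcher` and `ToyQuantMethod` are hypothetical stand-ins (the real `StandardDispatcher` and quant methods live in files not shown in this hunk), and only the dispatch → apply → combine ordering is taken from the diff.

```python
# Hedged sketch of the new FusedMoE forward flow; toy classes only.
from types import SimpleNamespace

import torch


class ToyDispatcher:
    def dispatch(self, hidden_states, topk_output):
        # Wrap inputs for the quant method / runner (pass-through style).
        return {"hidden_states": hidden_states, "topk_output": topk_output}

    def combine(self, combine_input):
        # Turn the quant method's output back into a plain tensor.
        return combine_input


class ToyQuantMethod:
    def apply(self, layer, dispatch_output):
        # In the diff, apply() now receives the dispatch output instead of raw
        # hidden states plus a moe_runner_config; here we just echo the tensor.
        return dispatch_output["hidden_states"]


def forward_moe(layer, hidden_states, topk_output):
    # Mirrors the new forward path: dispatch -> apply -> combine.
    dispatch_output = layer.dispatcher.dispatch(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = layer.quant_method.apply(layer=layer, dispatch_output=dispatch_output)
    return layer.dispatcher.combine(combine_input)


layer = SimpleNamespace(dispatcher=ToyDispatcher(), quant_method=ToyQuantMethod())
out = forward_moe(layer, torch.randn(4, 8), topk_output=None)
print(out.shape)  # torch.Size([4, 8])
```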
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py

```diff
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
```
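The docstring above works through a concrete example of the padding-and-sorting scheme. The sketch below is a pure-PyTorch reference that reproduces that documented behaviour on the CPU for clarity; it is not the fused `sgl_kernel` op, and `moe_align_block_size_ref` is a hypothetical name.

```python
# Reference implementation of the alignment described in the docstring above,
# assuming padding slots use the sentinel id topk_ids.numel().
from typing import Tuple

import torch


def moe_align_block_size_ref(
    topk_ids: torch.Tensor, block_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    pad_id = topk_ids.numel()  # sentinel index used for padding slots
    flat = topk_ids.flatten()

    sorted_chunks, expert_chunks = [], []
    # The docstring example numbers experts 1..4, so iterate over the expert
    # ids that actually occur rather than assuming a 0-based range.
    for expert in torch.unique(flat).tolist():
        token_idx = torch.nonzero(flat == expert, as_tuple=False).flatten()
        n_blocks = -(-token_idx.numel() // block_size)  # ceil division
        padded = torch.full((n_blocks * block_size,), pad_id, dtype=torch.int64)
        padded[: token_idx.numel()] = token_idx
        sorted_chunks.append(padded)
        expert_chunks.append(torch.full((n_blocks,), expert, dtype=torch.int64))

    sorted_ids = torch.cat(sorted_chunks)
    expert_ids = torch.cat(expert_chunks)
    num_tokens_post_pad = torch.tensor([sorted_ids.numel()])
    return sorted_ids, expert_ids, num_tokens_post_pad


topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
sorted_ids, expert_ids, num_post_pad = moe_align_block_size_ref(topk_ids, block_size=4)
print(sorted_ids.tolist())  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
print(expert_ids.tolist())  # [1, 2, 3, 4]: one block of block_size slots per expert
```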
sglang/srt/layers/moe/moe_runner/base.py

```diff
@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
```
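`PermuteMethodPool` and the `register_pre_permute` / `register_post_permute` / `register_fused_func` decorators form a small (format, backend) → callable registry that `MoeRunner` later queries. The sketch below is a self-contained illustration of that pattern; the `_demo` names are hypothetical, and the `"standard"` / `"triton"` strings merely stand in for whatever `DispatchOutputFormat` and `MoeRunnerBackend` values a real backend would register.

```python
# Simplified stand-in for the registry pattern in moe_runner/base.py above.
from typing import Callable, Dict, Tuple

# (dispatch-output format, runner backend) -> pre-permute callable.
_pre_permute_demo: Dict[Tuple[str, str], Callable] = {}


def register_pre_permute_demo(dispatch_output_name: str, runner_backend_name: str):
    def decorator(func: Callable) -> Callable:
        key = (dispatch_output_name, runner_backend_name)
        if key in _pre_permute_demo:
            raise ValueError(f"Pre-permute for {key} is already registered.")
        _pre_permute_demo[key] = func
        return func

    return decorator


@register_pre_permute_demo("standard", "triton")
def standard_to_triton(dispatch_output, quant_info, runner_config, running_state):
    # Convert the dispatcher's output into the runner's input layout; in the
    # real code this would build a TritonRunnerInput.
    return {"hidden_states": dispatch_output, "quant": quant_info}


# Lookup mirrors MoeRunner.run: resolve by (dispatch format, runner backend).
pre_permute = _pre_permute_demo[("standard", "triton")]
runner_input = pre_permute("hidden_states_placeholder", None, None, {})
print(runner_input["hidden_states"])
```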
sglang/srt/layers/moe/moe_runner/runner.py

```diff
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
```
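`MoeRunner.run` takes a fused fast path when a fused op is registered for the active (A2A backend, runner backend) pair and `SGLANG_CI_DISABLE_MOE_FUSED_FUNC=1` is not set; otherwise it goes through pre-permute, the runner core, and post-permute. Below is a toy sketch of that branching; all classes are hypothetical stand-ins, and only the control flow is taken from the file above.

```python
# Hedged sketch of MoeRunner.run's control flow with toy stand-in classes.
from typing import Callable, Optional


class ToyRunnerCore:
    def run(self, runner_input, quant_info, running_state):
        # Stand-in for TritonRunnerCore.run: the actual grouped-GEMM MoE step.
        return runner_input


def passthrough(x, *rest):
    return x


class ToyMoeRunner:
    def __init__(self, fused_func: Optional[Callable], pre, core, post):
        self.fused_func = fused_func  # stays None unless a fused op was registered
        self.pre, self.core, self.post = pre, core, post

    def run(self, dispatch_output, quant_info, config):
        if self.fused_func is not None:
            # Fast path: one registered op covers permute + compute + combine.
            return self.fused_func(dispatch_output, quant_info, config)
        running_state = {}
        runner_input = self.pre(dispatch_output, quant_info, config, running_state)
        runner_output = self.core.run(runner_input, quant_info, running_state)
        return self.post(runner_output, quant_info, config, running_state)


runner = ToyMoeRunner(None, passthrough, ToyRunnerCore(), passthrough)
print(runner.run("dispatch_output", quant_info=None, config=None))
```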