sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
|
@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
|
|
12
12
|
DeepEPConfig,
|
|
13
13
|
DeepEPDispatcher,
|
|
14
14
|
DeepEPLLCombineInput,
|
|
15
|
-
|
|
15
|
+
DeepEPLLDispatchOutput,
|
|
16
16
|
DeepEPNormalCombineInput,
|
|
17
|
-
|
|
17
|
+
DeepEPNormalDispatchOutput,
|
|
18
18
|
)
|
|
19
19
|
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
|
|
20
20
|
MooncakeCombineInput,
|
|
@@ -44,8 +44,8 @@ __all__ = [
|
|
|
44
44
|
"StandardCombineInput",
|
|
45
45
|
"DeepEPConfig",
|
|
46
46
|
"DeepEPDispatcher",
|
|
47
|
-
"
|
|
48
|
-
"
|
|
47
|
+
"DeepEPNormalDispatchOutput",
|
|
48
|
+
"DeepEPLLDispatchOutput",
|
|
49
49
|
"DeepEPLLCombineInput",
|
|
50
50
|
"DeepEPNormalCombineInput",
|
|
51
51
|
]
|
|
@@ -9,9 +9,9 @@ import torch
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from sglang.srt.layers.moe.token_dispatcher import (
|
|
11
11
|
DeepEPLLCombineInput,
|
|
12
|
-
|
|
12
|
+
DeepEPLLDispatchOutput,
|
|
13
13
|
DeepEPNormalCombineInput,
|
|
14
|
-
|
|
14
|
+
DeepEPNormalDispatchOutput,
|
|
15
15
|
StandardCombineInput,
|
|
16
16
|
StandardDispatchOutput,
|
|
17
17
|
)
|
|
@@ -28,22 +28,28 @@ class DispatchOutputChecker:
|
|
|
28
28
|
) -> TypeGuard[StandardDispatchOutput]:
|
|
29
29
|
return dispatch_output.format.is_standard()
|
|
30
30
|
|
|
31
|
+
@staticmethod
|
|
32
|
+
def format_is_triton_kernels(
|
|
33
|
+
dispatch_output: DispatchOutput,
|
|
34
|
+
) -> TypeGuard[StandardDispatchOutput]:
|
|
35
|
+
return dispatch_output.format.is_standard()
|
|
36
|
+
|
|
31
37
|
@staticmethod
|
|
32
38
|
def format_is_deepep_normal(
|
|
33
39
|
dispatch_output: DispatchOutput,
|
|
34
|
-
) -> TypeGuard[
|
|
40
|
+
) -> TypeGuard[DeepEPNormalDispatchOutput]:
|
|
35
41
|
return dispatch_output.format.is_deepep_normal()
|
|
36
42
|
|
|
37
43
|
@staticmethod
|
|
38
44
|
def format_is_deepep_ll(
|
|
39
45
|
dispatch_output: DispatchOutput,
|
|
40
|
-
) -> TypeGuard[
|
|
46
|
+
) -> TypeGuard[DeepEPLLDispatchOutput]:
|
|
41
47
|
return dispatch_output.format.is_deepep_ll()
|
|
42
48
|
|
|
43
49
|
@staticmethod
|
|
44
50
|
def format_is_deepep(
|
|
45
51
|
dispatch_output: DispatchOutput,
|
|
46
|
-
) -> TypeGuard[Union[
|
|
52
|
+
) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
|
|
47
53
|
return dispatch_output.format.is_deepep()
|
|
48
54
|
|
|
49
55
|
|
|
@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
|
|
58
58
|
logger = logging.getLogger(__name__)
|
|
59
59
|
|
|
60
60
|
|
|
61
|
-
class
|
|
61
|
+
class DeepEPNormalDispatchOutput(NamedTuple):
|
|
62
62
|
"""DeepEP normal dispatch output."""
|
|
63
63
|
|
|
64
64
|
hidden_states: torch.Tensor
|
|
@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
|
|
|
72
72
|
return DispatchOutputFormat.DEEPEP_NORMAL
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
class
|
|
75
|
+
class DeepEPLLDispatchOutput(NamedTuple):
|
|
76
76
|
"""DeepEP low latency dispatch output."""
|
|
77
77
|
|
|
78
78
|
hidden_states: torch.Tensor
|
|
@@ -87,14 +87,16 @@ class DeepEPLLOutput(NamedTuple):
|
|
|
87
87
|
return DispatchOutputFormat.DEEPEP_LL
|
|
88
88
|
|
|
89
89
|
|
|
90
|
-
assert isinstance(
|
|
91
|
-
assert isinstance(
|
|
90
|
+
assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
|
|
91
|
+
assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)
|
|
92
92
|
|
|
93
93
|
|
|
94
94
|
class DeepEPNormalCombineInput(NamedTuple):
|
|
95
95
|
"""DeepEP normal combine input."""
|
|
96
96
|
|
|
97
|
-
|
|
97
|
+
hidden_states: torch.Tensor
|
|
98
|
+
topk_ids: torch.Tensor
|
|
99
|
+
topk_weights: torch.Tensor
|
|
98
100
|
|
|
99
101
|
@property
|
|
100
102
|
def format(self) -> CombineInputFormat:
|
|
@@ -104,7 +106,9 @@ class DeepEPNormalCombineInput(NamedTuple):
|
|
|
104
106
|
class DeepEPLLCombineInput(NamedTuple):
|
|
105
107
|
"""DeepEP low latency combine input."""
|
|
106
108
|
|
|
107
|
-
|
|
109
|
+
hidden_states: torch.Tensor
|
|
110
|
+
topk_ids: torch.Tensor
|
|
111
|
+
topk_weights: torch.Tensor
|
|
108
112
|
|
|
109
113
|
@property
|
|
110
114
|
def format(self) -> CombineInputFormat:
|
|
@@ -327,7 +331,7 @@ class _DeepEPDispatcherImplBase:
|
|
|
327
331
|
hidden_states: torch.Tensor,
|
|
328
332
|
topk_ids: torch.Tensor,
|
|
329
333
|
topk_weights: torch.Tensor,
|
|
330
|
-
overlap_args: Optional[
|
|
334
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
331
335
|
):
|
|
332
336
|
raise NotImplementedError
|
|
333
337
|
|
|
@@ -383,7 +387,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
|
383
387
|
else:
|
|
384
388
|
hidden_states_scale = None
|
|
385
389
|
|
|
386
|
-
return
|
|
390
|
+
return DeepEPNormalDispatchOutput(
|
|
387
391
|
hidden_states,
|
|
388
392
|
hidden_states_scale,
|
|
389
393
|
topk_ids,
|
|
@@ -457,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
|
457
461
|
hidden_states: torch.Tensor,
|
|
458
462
|
topk_ids: torch.Tensor,
|
|
459
463
|
topk_weights: torch.Tensor,
|
|
460
|
-
overlap_args: Optional[
|
|
464
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
461
465
|
):
|
|
462
466
|
|
|
463
467
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
|
@@ -562,7 +566,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
|
562
566
|
else:
|
|
563
567
|
hidden_states_scale = None
|
|
564
568
|
|
|
565
|
-
deepep_output =
|
|
569
|
+
deepep_output = DeepEPLLDispatchOutput(
|
|
566
570
|
hidden_states,
|
|
567
571
|
hidden_states_scale,
|
|
568
572
|
topk_ids,
|
|
@@ -613,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
|
613
617
|
hidden_states: torch.Tensor,
|
|
614
618
|
topk_ids: torch.Tensor,
|
|
615
619
|
topk_weights: torch.Tensor,
|
|
616
|
-
overlap_args: Optional[
|
|
620
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
617
621
|
):
|
|
618
622
|
hidden_states, event, hook = self._combine_core(
|
|
619
623
|
hidden_states,
|
|
@@ -639,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
|
639
643
|
hidden_states: torch.Tensor,
|
|
640
644
|
topk_ids: torch.Tensor,
|
|
641
645
|
topk_weights: torch.Tensor,
|
|
642
|
-
overlap_args: Optional[
|
|
646
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
643
647
|
):
|
|
644
648
|
buffer = self._get_buffer()
|
|
645
649
|
|
|
@@ -756,18 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
|
|
|
756
760
|
del self._dispatch_intermediate_state
|
|
757
761
|
return self._get_impl().dispatch_b(*inner_state)
|
|
758
762
|
|
|
759
|
-
def combine(
|
|
760
|
-
self
|
|
763
|
+
def combine(
|
|
764
|
+
self,
|
|
765
|
+
combine_input: CombineInput,
|
|
766
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
767
|
+
) -> Tuple:
|
|
768
|
+
self.combine_a(combine_input, overlap_args)
|
|
761
769
|
ret = self.combine_b()
|
|
762
770
|
return ret
|
|
763
771
|
|
|
764
772
|
def combine_a(
|
|
765
773
|
self,
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
topk_weights: torch.Tensor,
|
|
769
|
-
overlap_args: Optional["CombineOverlapArgs"] = None,
|
|
774
|
+
combine_input: CombineInput,
|
|
775
|
+
overlap_args: Optional[CombineOverlapArgs] = None,
|
|
770
776
|
):
|
|
777
|
+
hidden_states, topk_ids, topk_weights = combine_input
|
|
771
778
|
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
|
772
779
|
inner_state = self._get_impl().combine_a(
|
|
773
780
|
hidden_states=hidden_states,
|
|
@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
|
|
|
88
88
|
topk_output = topk_output._replace(
|
|
89
89
|
topk_ids=self.local_expert_mapping[topk_output.topk_ids]
|
|
90
90
|
)
|
|
91
|
-
elif TopKOutputChecker.
|
|
91
|
+
elif TopKOutputChecker.format_is_triton_kernels(topk_output):
|
|
92
92
|
raise NotImplementedError()
|
|
93
93
|
|
|
94
94
|
return StandardDispatchOutput(
|
sglang/srt/layers/moe/topk.py
CHANGED
|
@@ -111,10 +111,10 @@ class TopKOutputChecker:
|
|
|
111
111
|
return topk_output.format.is_standard()
|
|
112
112
|
|
|
113
113
|
@staticmethod
|
|
114
|
-
def
|
|
114
|
+
def format_is_triton_kernels(
|
|
115
115
|
topk_output: TopKOutput,
|
|
116
116
|
) -> TypeGuard[TritonKernelTopKOutput]:
|
|
117
|
-
return topk_output.format.
|
|
117
|
+
return topk_output.format.is_triton_kernels()
|
|
118
118
|
|
|
119
119
|
@staticmethod
|
|
120
120
|
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
|
|
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
|
|
|
129
129
|
def is_standard(self) -> bool:
|
|
130
130
|
return self == TopKOutputFormat.STANDARD
|
|
131
131
|
|
|
132
|
-
def
|
|
132
|
+
def is_triton_kernels(self) -> bool:
|
|
133
133
|
return self == TopKOutputFormat.TRITON_KERNEL
|
|
134
134
|
|
|
135
135
|
def is_bypassed(self) -> bool:
|
|
@@ -254,7 +254,7 @@ class TopK(CustomOp):
|
|
|
254
254
|
) -> TopKOutput:
|
|
255
255
|
if self.topk_config.output_format is not None:
|
|
256
256
|
output_format = self.topk_config.output_format
|
|
257
|
-
elif get_moe_runner_backend().
|
|
257
|
+
elif get_moe_runner_backend().is_triton_kernels():
|
|
258
258
|
output_format = TopKOutputFormat.TRITON_KERNEL
|
|
259
259
|
elif (
|
|
260
260
|
should_use_flashinfer_trtllm_moe()
|
|
@@ -314,16 +314,41 @@ class TopK(CustomOp):
|
|
|
314
314
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
|
315
315
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
|
316
316
|
) -> TopKOutput:
|
|
317
|
-
global_num_experts = router_logits.shape[-1]
|
|
318
317
|
|
|
319
|
-
|
|
320
|
-
|
|
318
|
+
use_grouped_topk = self.topk_config.use_grouped_topk
|
|
319
|
+
torch_native = self.topk_config.torch_native
|
|
320
|
+
renormalize = self.topk_config.renormalize
|
|
321
321
|
|
|
322
|
+
if not use_grouped_topk and not torch_native:
|
|
323
|
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
|
|
324
|
+
router_logits,
|
|
325
|
+
k=self.topk_config.top_k,
|
|
326
|
+
)
|
|
327
|
+
topk_weights = topk_weights.to(torch.float32)
|
|
328
|
+
|
|
329
|
+
if renormalize:
|
|
330
|
+
topk_weights_sum = (
|
|
331
|
+
topk_weights.sum(dim=-1, keepdim=True)
|
|
332
|
+
if self.topk_config.num_fused_shared_experts == 0
|
|
333
|
+
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
|
334
|
+
)
|
|
335
|
+
topk_weights = topk_weights / topk_weights_sum
|
|
336
|
+
|
|
337
|
+
if expert_location_dispatch_info is not None:
|
|
338
|
+
topk_ids = topk_ids_logical_to_physical(
|
|
339
|
+
topk_ids, expert_location_dispatch_info
|
|
340
|
+
)
|
|
341
|
+
get_global_expert_distribution_recorder().on_select_experts(
|
|
342
|
+
topk_ids=topk_ids
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
return StandardTopKOutput(topk_weights, topk_ids, _)
|
|
346
|
+
if use_grouped_topk and not torch_native and router_logits.shape[-1] == 256:
|
|
347
|
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
|
322
348
|
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
|
323
|
-
router_logits = router_logits.to(torch.float32)
|
|
324
349
|
|
|
325
350
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
|
326
|
-
router_logits,
|
|
351
|
+
router_logits.to(torch.float32),
|
|
327
352
|
k=self.topk_config.top_k,
|
|
328
353
|
bias=self.topk_config.correction_bias.to(torch.float32),
|
|
329
354
|
k_group=self.topk_config.topk_group,
|
|
@@ -335,7 +360,7 @@ class TopK(CustomOp):
|
|
|
335
360
|
eps=float(1e-20),
|
|
336
361
|
)
|
|
337
362
|
|
|
338
|
-
if
|
|
363
|
+
if renormalize:
|
|
339
364
|
topk_weights_sum = (
|
|
340
365
|
topk_weights.sum(dim=-1, keepdim=True)
|
|
341
366
|
if self.topk_config.num_fused_shared_experts == 0
|
sglang/srt/layers/moe/utils.py
CHANGED
|
@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
|
|
|
51
51
|
AUTO = "auto"
|
|
52
52
|
DEEP_GEMM = "deep_gemm"
|
|
53
53
|
TRITON = "triton"
|
|
54
|
-
|
|
54
|
+
TRITON_KERNELS = "triton_kernel"
|
|
55
55
|
FLASHINFER_TRTLLM = "flashinfer_trtllm"
|
|
56
56
|
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
|
57
57
|
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
|
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
|
|
|
67
67
|
def is_triton(self):
|
|
68
68
|
return self == MoeRunnerBackend.TRITON
|
|
69
69
|
|
|
70
|
-
def
|
|
71
|
-
return self == MoeRunnerBackend.
|
|
70
|
+
def is_triton_kernels(self):
|
|
71
|
+
return self == MoeRunnerBackend.TRITON_KERNELS
|
|
72
72
|
|
|
73
73
|
def is_flashinfer_trtllm(self):
|
|
74
74
|
return self == MoeRunnerBackend.FLASHINFER_TRTLLM
|
|
@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
|
|
|
152
152
|
def get_moe_a2a_backend() -> MoeA2ABackend:
|
|
153
153
|
global MOE_A2A_BACKEND
|
|
154
154
|
if MOE_A2A_BACKEND is None:
|
|
155
|
-
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
|
156
155
|
MOE_A2A_BACKEND = MoeA2ABackend.NONE
|
|
157
156
|
return MOE_A2A_BACKEND
|
|
158
157
|
|
sglang/srt/layers/pooler.py
CHANGED
|
@@ -20,7 +20,9 @@ class PoolingType(IntEnum):
|
|
|
20
20
|
|
|
21
21
|
@dataclass
|
|
22
22
|
class EmbeddingPoolerOutput:
|
|
23
|
-
|
|
23
|
+
# Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different
|
|
24
|
+
# due to different per-request matryoshka dim truncation
|
|
25
|
+
embeddings: torch.Tensor | list[torch.Tensor]
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class Pooler(nn.Module):
|
|
@@ -42,6 +44,7 @@ class Pooler(nn.Module):
|
|
|
42
44
|
def forward(
|
|
43
45
|
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
|
44
46
|
) -> EmbeddingPoolerOutput:
|
|
47
|
+
|
|
45
48
|
if self.pooling_type == PoolingType.LAST:
|
|
46
49
|
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
|
|
47
50
|
pooled_data = hidden_states[last_token_indices]
|
|
@@ -53,8 +56,24 @@ class Pooler(nn.Module):
|
|
|
53
56
|
else:
|
|
54
57
|
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
|
55
58
|
|
|
59
|
+
if forward_batch.dimensions is not None:
|
|
60
|
+
all_same_dimensions = len(set(forward_batch.dimensions)) == 1
|
|
61
|
+
if all_same_dimensions:
|
|
62
|
+
pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
|
|
63
|
+
else:
|
|
64
|
+
pooled_data = [
|
|
65
|
+
tensor[..., :dim]
|
|
66
|
+
for tensor, dim in zip(pooled_data, forward_batch.dimensions)
|
|
67
|
+
]
|
|
68
|
+
|
|
56
69
|
if self.normalize:
|
|
57
|
-
|
|
70
|
+
if isinstance(pooled_data, list):
|
|
71
|
+
pooled_data = [
|
|
72
|
+
nn.functional.normalize(tensor, p=2, dim=-1)
|
|
73
|
+
for tensor in pooled_data
|
|
74
|
+
]
|
|
75
|
+
else:
|
|
76
|
+
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)
|
|
58
77
|
|
|
59
78
|
return EmbeddingPoolerOutput(embeddings=pooled_data)
|
|
60
79
|
|
|
@@ -7,36 +7,16 @@ from typing import TYPE_CHECKING, Dict, Optional, Type
|
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
|
16
|
-
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
|
17
|
-
GPTQMarlin24Config,
|
|
18
|
-
)
|
|
19
|
-
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|
20
|
-
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
|
21
|
-
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
|
22
|
-
|
|
23
|
-
VLLM_AVAILABLE = True
|
|
24
|
-
except ImportError as e:
|
|
25
|
-
VLLM_AVAILABLE = False
|
|
26
|
-
VLLM_IMPORT_ERROR = e
|
|
27
|
-
|
|
28
|
-
# Define empty classes as placeholders when vllm is not available
|
|
29
|
-
class DummyConfig:
|
|
30
|
-
def override_quantization_method(self, *args, **kwargs):
|
|
31
|
-
return None
|
|
32
|
-
|
|
33
|
-
AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
|
|
34
|
-
ExpertsInt8Config
|
|
35
|
-
) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
|
|
36
|
-
DummyConfig
|
|
37
|
-
)
|
|
10
|
+
|
|
11
|
+
# Define empty classes as placeholders when vllm is not available
|
|
12
|
+
class DummyConfig:
|
|
13
|
+
def override_quantization_method(self, *args, **kwargs):
|
|
14
|
+
return None
|
|
38
15
|
|
|
39
16
|
|
|
17
|
+
CompressedTensorsConfig = DummyConfig
|
|
18
|
+
|
|
19
|
+
from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
|
|
40
20
|
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
|
|
41
21
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
42
22
|
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
|
@@ -45,6 +25,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
|
|
45
25
|
)
|
|
46
26
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
|
47
27
|
from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
|
|
28
|
+
from sglang.srt.layers.quantization.gguf import GGUFConfig
|
|
48
29
|
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
|
49
30
|
from sglang.srt.layers.quantization.modelopt_quant import (
|
|
50
31
|
ModelOptFp4Config,
|
|
@@ -64,7 +45,7 @@ _is_mxfp_supported = mxfp_supported()
|
|
|
64
45
|
if TYPE_CHECKING:
|
|
65
46
|
from sglang.srt.layers.moe.topk import TopKOutput
|
|
66
47
|
|
|
67
|
-
# Base quantization methods
|
|
48
|
+
# Base quantization methods
|
|
68
49
|
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
69
50
|
"fp8": Fp8Config,
|
|
70
51
|
"blockwise_int8": BlockInt8Config,
|
|
@@ -75,6 +56,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
|
75
56
|
"w8a8_fp8": W8A8Fp8Config,
|
|
76
57
|
"awq": AWQConfig,
|
|
77
58
|
"awq_marlin": AWQMarlinConfig,
|
|
59
|
+
"gguf": GGUFConfig,
|
|
78
60
|
"gptq": GPTQConfig,
|
|
79
61
|
"gptq_marlin": GPTQMarlinConfig,
|
|
80
62
|
"moe_wna16": MoeWNA16Config,
|
|
@@ -83,6 +65,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
|
83
65
|
"w4afp8": W4AFp8Config,
|
|
84
66
|
"petit_nvfp4": PetitNvFp4Config,
|
|
85
67
|
"fbgemm_fp8": FBGEMMFp8Config,
|
|
68
|
+
"auto-round": AutoRoundConfig,
|
|
86
69
|
}
|
|
87
70
|
|
|
88
71
|
|
|
@@ -102,20 +85,8 @@ elif _is_mxfp_supported and is_hip():
|
|
|
102
85
|
"mxfp4": Mxfp4Config,
|
|
103
86
|
}
|
|
104
87
|
)
|
|
105
|
-
# VLLM-dependent quantization methods
|
|
106
|
-
VLLM_QUANTIZATION_METHODS = {
|
|
107
|
-
"aqlm": AQLMConfig,
|
|
108
|
-
"deepspeedfp": DeepSpeedFPConfig,
|
|
109
|
-
"tpu_int8": Int8TpuConfig,
|
|
110
|
-
"marlin": MarlinConfig,
|
|
111
|
-
"gguf": GGUFConfig,
|
|
112
|
-
"gptq_marlin_24": GPTQMarlin24Config,
|
|
113
|
-
"bitsandbytes": BitsAndBytesConfig,
|
|
114
|
-
"qqq": QQQConfig,
|
|
115
|
-
"experts_int8": ExpertsInt8Config,
|
|
116
|
-
}
|
|
117
88
|
|
|
118
|
-
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS
|
|
89
|
+
QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}
|
|
119
90
|
|
|
120
91
|
|
|
121
92
|
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|
@@ -124,50 +95,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|
|
124
95
|
f"Invalid quantization method: {quantization}. "
|
|
125
96
|
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
|
126
97
|
)
|
|
127
|
-
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
|
|
128
|
-
raise ValueError(
|
|
129
|
-
f"{quantization} quantization requires some operators from vllm. "
|
|
130
|
-
f"Please install vllm by `pip install vllm==0.9.0.1`\n"
|
|
131
|
-
f"Import error: {VLLM_IMPORT_ERROR}"
|
|
132
|
-
)
|
|
133
98
|
|
|
134
99
|
return QUANTIZATION_METHODS[quantization]
|
|
135
100
|
|
|
136
101
|
|
|
137
102
|
original_isinstance = builtins.isinstance
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
|
141
|
-
"""
|
|
142
|
-
Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
|
|
143
|
-
can recognize sglang layers
|
|
144
|
-
"""
|
|
145
|
-
if not VLLM_AVAILABLE:
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
if reverse:
|
|
149
|
-
builtins.isinstance = original_isinstance
|
|
150
|
-
return
|
|
151
|
-
|
|
152
|
-
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
153
|
-
from vllm.model_executor.layers.linear import LinearBase
|
|
154
|
-
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
155
|
-
VocabParallelEmbedding,
|
|
156
|
-
)
|
|
157
|
-
|
|
158
|
-
from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
|
|
159
|
-
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
|
|
160
|
-
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
161
|
-
VocabParallelEmbedding as PatchedVocabParallelEmbedding,
|
|
162
|
-
)
|
|
163
|
-
|
|
164
|
-
def patched_isinstance(obj, classinfo):
|
|
165
|
-
if classinfo is LinearBase:
|
|
166
|
-
return original_isinstance(obj, PatchedLinearBase)
|
|
167
|
-
if classinfo is FusedMoE:
|
|
168
|
-
return original_isinstance(obj, PatchedFusedMoE)
|
|
169
|
-
if classinfo is VocabParallelEmbedding:
|
|
170
|
-
return original_isinstance(obj, PatchedVocabParallelEmbedding)
|
|
171
|
-
return original_isinstance(obj, classinfo)
|
|
172
|
-
|
|
173
|
-
builtins.isinstance = patched_isinstance
|