sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/auto_round.py (new file)
@@ -0,0 +1,394 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+from fractions import Fraction
+from typing import Any, Optional, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+from sglang.srt.layers.quantization.utils import get_scalar_types
+
+ScalarType, scalar_types = get_scalar_types()
+
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+
+class AutoRoundConfig(QuantizationConfig):
+    """Config class for AutoRound.
+    Reference: https://arxiv.org/pdf/2309.05516
+    """
+
+    SUPPORTED_BITS = {2, 3, 4, 8}
+    SUPPORTED_DTYPES = {"int"}
+    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
+    SUPPORTED_BACKENDS = {"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin"}
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        sym: bool = True,
+        packing_format: str = "auto_round:auto_gptq",
+        block_name_to_quantize: Optional[Union[str, list[str]]] = None,
+        extra_config: Optional[dict[str, Any]] = None,
+        data_type: str = "int",
+        backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        if weight_bits not in self.SUPPORTED_BITS:
+            raise ValueError(
+                f"Unsupported weight_bits: {weight_bits}, "
+                f"currently only support {self.SUPPORTED_BITS}"
+            )
+        if data_type not in self.SUPPORTED_DTYPES:
+            raise ValueError(
+                f"Unsupported data_type: {data_type},"
+                f" currently only support {self.SUPPORTED_DTYPES}"
+            )
+        if packing_format not in self.SUPPORTED_FORMATS:
+            raise ValueError(
+                f"Unsupported packing_format: {packing_format}, "
+                f"currently only support {self.SUPPORTED_FORMATS}"
+            )
+        if backend not in self.SUPPORTED_BACKENDS:
+            raise ValueError(
+                f"Unsupported backend: {backend}, "
+                f"currently only support {self.SUPPORTED_BACKENDS}"
+            )
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.sym = sym
+        self.packing_format = packing_format
+        self.block_name_to_quantize = (
+            block_name_to_quantize.split(",")
+            if isinstance(block_name_to_quantize, str)
+            else block_name_to_quantize
+        )
+        self.extra_config = extra_config
+        self.data_type = data_type
+        self.backend = backend
+        self.pack_factor = Fraction(32, weight_bits)
+
+    def __repr__(self) -> str:
+        return (
+            f"AutoRoundConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, sym={self.sym})"
+        )
+
+    @classmethod
+    def get_name(cls):
+        return "auto-round"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantization_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
+        return cls(
+            weight_bits=cls.get_from_keys(config, ["bits"]),
+            group_size=cls.get_from_keys(config, ["group_size"]),
+            sym=cls.get_from_keys(config, ["sym"]),
+            packing_format=cls.get_from_keys_or(
+                config,
+                ["packing_format"],
+                "auto_round:auto_gptq",
+            ),
+            block_name_to_quantize=cls.get_from_keys_or(
+                config, ["block_name_to_quantize", "to_quant_block_names"], None
+            ),
+            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
+            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
+            backend=cls.get_from_keys_or(
+                config, ["backend", "vllm_backend", "sglang_backend"], "auto"
+            ),
+        )
+
+    def get_scaled_act_names(self) -> list[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    def get_layer_config(self, layer, layer_name: str):
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        def get_config(name: str, quantized: bool = True):
+            if not self.extra_config:
+                return (
+                    self.weight_bits if quantized else 16,
+                    self.group_size if quantized else -1,
+                    self.sym if quantized else True,
+                )
+
+            # Exact match first
+            if name in self.extra_config:
+                cfg = self.extra_config[name]
+                return (
+                    cfg.get("bits", self.weight_bits if quantized else 16),
+                    cfg.get("group_size", self.group_size if quantized else -1),
+                    cfg.get("sym", self.sym if quantized else True),
+                )
+
+            REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
+            for pattern, cfg in self.extra_config.items():
+                if not isinstance(pattern, str) or not any(
+                    c in REGEX_SPECIAL_CHARS for c in pattern
+                ):
+                    continue
+
+                try:
+                    if re.fullmatch(pattern, name):
+                        return (
+                            cfg.get("bits", self.weight_bits if quantized else 16),
+                            cfg.get("group_size", self.group_size if quantized else -1),
+                            cfg.get("sym", self.sym if quantized else True),
+                        )
+                except re.error:
+                    # Invalid regex, ignore.
+                    continue
+
+            return (
+                self.weight_bits if quantized else 16,
+                self.group_size if quantized else -1,
+                self.sym if quantized else True,
+            )
+
+        # 1. Exact match from config
+        if self.extra_config and layer_name in self.extra_config:
+            return get_config(layer_name)
+
+        # 2. Determine whether layer should be quantized
+        quantized = not isinstance(layer, ParallelLMHead)
+        if self.block_name_to_quantize:
+            quantized = any(
+                layer_name.startswith(name) for name in self.block_name_to_quantize
+            )
+
+        # 3. Handle fused MoE
+        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
+            moe_configs = [
+                get_config(name, quantized)
+                for name in self.extra_config
+                if name.startswith(layer_name)
+            ]
+            if moe_configs:
+                if len(set(moe_configs)) == 1:
+                    return moe_configs[0]
+                raise ValueError(
+                    f"Fused MoE layer '{layer_name}' requires "
+                    f"consistent quant config for all sub-layers"
+                )
+
+        # 4. Handle fused QKV or other patterns
+        if self.extra_config:
+            for fusion_key, sub_keys in self.packed_modules_mapping.items():
+                if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
+                    sub_names = [
+                        layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
+                    ]
+                    sub_configs = [get_config(name, quantized) for name in sub_names]
+                    if len(set(sub_configs)) == 1:
+                        return sub_configs[0]
+                    raise ValueError(
+                        f"Fused module '{layer_name}' requires "
+                        f"consistent quant config for {sub_names}"
+                    )
+
+        # 5. Fallback or try a regular expression match
+        return get_config(layer_name, quantized)
+
+    def check_quantized(self, weight_bits: int) -> bool:
+        return weight_bits < 16
+
+    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
+                AWQ_TYPE_MAP[weight_bits], group_size, not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.awq import (
+                AWQMarlinConfig,
+                AWQMarlinLinearMethod,
+                AWQMoEMethod,
+            )
+
+            quant_args_marlin = AWQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+                lm_head_quantized=False,
+                full_config={},
+                modules_to_not_convert=[],
+            )
+        else:
+            from sglang.srt.layers.quantization.awq import AWQConfig, AWQLinearMethod
+
+            quant_args = AWQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                return AWQMoEMethod(quant_args_marlin)
+            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+            config = {
+                "quant_method": "awq",
+                "bits": weight_bits,
+                "group_size": group_size,
+                "zero_point": not sym,
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return AWQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return AWQLinearMethod(quant_args)
+        return None
+
+    def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
+                GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.gptq import (
+                GPTQMarlinConfig,
+                GPTQMarlinLinearMethod,
+                GPTQMarlinMoEMethod,
+            )
+
+            quant_args_marlin = GPTQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                is_sym=sym,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+                full_config={},
+            )
+        else:
+            from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQLinearMethod
+
+            quant_args = GPTQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+                config = {
+                    "quant_method": "gptq",
+                    "bits": weight_bits,
+                    "group_size": group_size,
+                    "sym": sym,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix
+                )
+            return GPTQMarlinMoEMethod(quant_args_marlin)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return GPTQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return GPTQLinearMethod(quant_args)
+
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        # TODO enable CPU quant method later
+        if "gptq" in self.packing_format or "gptq" in self.backend:
+            return self.apply_gptq_quant_layer(layer, prefix)
+        if "awq" in self.packing_format or "awq" in self.backend:
+            return self.apply_awq_quant_layer(layer, prefix)
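The new AutoRoundConfig above is built from a checkpoint's quantization config via from_config and then resolves per-layer (bits, group_size, sym) through get_layer_config. A minimal sketch of that flow, assuming an sglang installation that ships this module; the config values and the layer name in extra_config are purely illustrative:

# Illustrative only: a hypothetical quantization_config fed through the new
# AutoRoundConfig.from_config shown above.
from sglang.srt.layers.quantization.auto_round import AutoRoundConfig

hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "packing_format": "auto_round:auto_gptq",
    # Optional per-layer overrides; keys may be exact layer names or regexes.
    "extra_config": {"model.layers.0.mlp.down_proj": {"bits": 8}},
}

cfg = AutoRoundConfig.from_config(hf_quant_config)
print(cfg)  # AutoRoundConfig(weight_bits=4, group_size=128, sym=True)

# get_layer_config then resolves (bits, group_size, sym) per layer: the exact-match
# override above yields 8 bits for that projection, everything else falls back to
# the global 4-bit / group-size-128 setting.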
sglang/srt/layers/quantization/awq.py
@@ -840,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

-        # The input must currently be float16
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
-
        orig_dtype = x.dtype
-        x = x.half()

        topk_weights, topk_ids, router_logits = topk_output

sglang/srt/layers/quantization/base_config.py
@@ -179,6 +179,13 @@ class QuantizationConfig(ABC):
        elif "NVFP4" in quant_algo or "FP4" in quant_algo:
            return "modelopt_fp4"

+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
        return None

    @staticmethod
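The added lines cover the case where hf_quant_config is already a parsed quant config rather than a raw quant_algo string, falling back to its quant_method field. A standalone sketch of that lookup (the helper name detect_modelopt_method is hypothetical, not sglang's API):

from typing import Optional

def detect_modelopt_method(hf_quant_config: dict) -> Optional[str]:
    # Mirrors the added fallback: a parsed config carries an explicit quant_method.
    method = hf_quant_config.get("quant_method", "")
    if method == "modelopt_fp8":
        return "modelopt_fp8"
    if method == "modelopt_fp4":
        return "modelopt_fp4"
    return None

assert detect_modelopt_method({"quant_method": "modelopt_fp8"}) == "modelopt_fp8"
assert detect_modelopt_method({"quant_method": "awq"}) is None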
sglang/srt/layers/quantization/fp8.py
@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
        self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )

    def create_weights(
        self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        assert self.quant_config.activation_scheme == "dynamic"
-        if self.use_cutlass_fused_experts_fp8:
-            self.ab_strides1 = torch.full(
-                (num_experts,),
-                hidden_size,
-                device=w13_weight.device,
-                dtype=torch.int64,
-            )
-            self.c_strides1 = torch.full(
-                (num_experts,),
-                2 * intermediate_size_per_partition,
-                device=w13_weight.device,
-                dtype=torch.int64,
-            )
-            self.ab_strides2 = torch.full(
-                (num_experts,),
-                intermediate_size_per_partition,
-                device=w2_weight.device,
-                dtype=torch.int64,
-            )
-            self.c_strides2 = torch.full(
-                (num_experts,),
-                hidden_size,
-                device=w2_weight.device,
-                dtype=torch.int64,
-            )
-            self.workspace = torch.empty(
-                90000, device=w13_weight.device, dtype=torch.uint8
-            )
-            self.a_ptr = torch.empty(
-                num_experts, device=w13_weight.device, dtype=torch.int64
-            )
-            self.b_ptr = torch.empty(
-                num_experts, device=w13_weight.device, dtype=torch.int64
-            )
-            self.out_ptr = torch.empty(
-                num_experts, device=w13_weight.device, dtype=torch.int64
-            )
-            self.a_scales_ptr = torch.empty(
-                num_experts, device=w13_weight.device, dtype=torch.int64
-            )
-            self.b_scales_ptr = torch.empty(
-                num_experts, device=w13_weight.device, dtype=torch.int64
-            )
-            self.expert_offsets = torch.empty(
-                num_experts + 1, device=w13_weight.device, dtype=torch.int32
-            )
-            self.problem_sizes1 = torch.empty(
-                num_experts, 3, device=w13_weight.device, dtype=torch.int32
-            )
-            self.problem_sizes2 = torch.empty(
-                num_experts, 3, device=w13_weight.device, dtype=torch.int32
-            )
+        if self._should_use_cutlass_fused_experts():
+            self._ensure_cutlass_buffers_initialized(layer)

        else:
            # Allocate 2 scales for w1 and w3 respectively.
@@ -1039,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output
        moe_runner_config = self.moe_runner_config

        if use_intel_amx_backend(layer):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
            x, topk_weights = apply_topk_weights_cpu(
                moe_runner_config.apply_router_weight_on_input, topk_weights, x
            )
@@ -1072,17 +1016,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        ret = self.maybe_apply_hip_fused_experts(
            layer,
            x,
-            topk_output,
+            dispatch_output.topk_output,
            moe_runner_config.activation,
            moe_runner_config.no_combine,
        )
        if ret is not None:
            return StandardCombineInput(hidden_states=ret)

-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
            output = cutlass_fused_experts_fp8(
                x,
                layer.w13_weight.transpose(1, 2),
@@ -1171,6 +1115,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):

        return self.runner.run(dispatch_output, quant_info)

+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+
+        self._cutlass_buffers_ready = True
+
    def apply_with_router_logits(
        self,
        layer: torch.nn.Module,
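In the hunks above, the eager stride/pointer allocation during weight creation is replaced by _should_use_cutlass_fused_experts (a moe runner backend check with an SGLANG_CUTLASS_MOE override) plus _ensure_cutlass_buffers_initialized, which allocates the Cutlass buffers once on first use. A minimal, self-contained sketch of that one-shot guard pattern; the LazyBuffers class and ensure_buffers name are illustrative, not part of sglang:

import torch

class LazyBuffers:
    def ensure_buffers(self, num_experts: int, device: torch.device) -> None:
        # Re-entrant guard: only allocate the first time we are called.
        if getattr(self, "_buffers_ready", False):
            return
        self.expert_offsets = torch.empty(
            num_experts + 1, device=device, dtype=torch.int32
        )
        self.problem_sizes = torch.empty(
            num_experts, 3, device=device, dtype=torch.int32
        )
        self._buffers_ready = True

b = LazyBuffers()
b.ensure_buffers(num_experts=8, device=torch.device("cpu"))
b.ensure_buffers(num_experts=8, device=torch.device("cpu"))  # no-op on the second call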
sglang/srt/layers/quantization/fp8_kernel.py
@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
            x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
            device=device,
            dtype=torch.float32,
-        ).
+        ).transpose(-1, -2)[: x_shape[-2], :]
    else:
        return torch.empty(
            (x_shape[-1] // group_size,) + x_shape[:-1],
sglang/srt/layers/quantization/fp8_utils.py
@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import ceil_div,
+from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader

try:
    from vllm import _custom_ops as ops
@@ -129,7 +129,7 @@ def cutlass_block_fp8_supported() -> bool:
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
ENABLE_FLASHINFER_GEMM = (
    get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
-    and
+    and is_blackwell_supported()
    and is_flashinfer_available()
)
if ENABLE_FLASHINFER_GEMM: