sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py (new file)

@@ -0,0 +1,56 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+__all__ = ["CompressedTensorsScheme"]
+
+
+class CompressedTensorsScheme(ABC):
+    """
+    Abstract class used to describe the weight creation and forward pass
+    of different quantization schemes supported by CompressedTensors.
+    """
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_weights(self, *args, **kwargs):
+        """
+        Weight creation for the particular scheme. Inputs to this function
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ):
+        """
+        Run the forward pass for the particular scheme. This is where
+        scheme-specific dequant/quant steps/kernels should be applied.
+
+        :param layer: torch.nn.Module with the registered weights and
+            other parameters relevant to the particular scheme.
+        :param x: input to the layer
+        :param bias: bias parameter
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
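The abstract base class above is the extension point for the compressed-tensors support added in this release; the W8A8 FP8 scheme in the next file is its first concrete implementation. As a rough sketch of how a scheme plugs into this interface, here is a hypothetical no-op scheme (`IdentityScheme` is illustrative only and does not ship with sglang; the import assumes sglang 0.4.4.post2 is installed so the new module path resolves):

```python
# Illustrative only: a hypothetical scheme that keeps weights in fp16 and does
# no quantization, just to show what the abstract interface above expects.
from typing import Optional

import torch

from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,  # exported by the new schemes package in this release
)


class IdentityScheme(CompressedTensorsScheme):
    @classmethod
    def get_min_capability(cls) -> int:
        # Lowest CUDA compute capability this scheme is willing to run on.
        return 70

    def create_weights(self, layer: torch.nn.Module, output_size: int, input_size: int, **kwargs):
        # Real schemes register quantized weights plus their scales here.
        weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=torch.float16), requires_grad=False
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # Nothing to requantize or transpose for this toy scheme.
        pass

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]):
        return torch.nn.functional.linear(x, layer.weight, bias)
```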
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (new file)

@@ -0,0 +1,162 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from torch.nn import Parameter
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+    Fp8LinearOp,
+    maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+
+__all__ = ["CompressedTensorsW8A8Fp8"]
+
+
+class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # lovelace and up
+        return 89
+
+    def process_weights_after_loading(self, layer) -> None:
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per tensor
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+
+            if is_fp8_fnuz():
+                input_scale = getattr(layer, "input_scale", None)
+
+                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=max_w_scale, input_scale=input_scale
+                )
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            weight = layer.weight
+
+            if is_fp8_fnuz():
+                input_scale = getattr(layer, "input_scale", None)
+
+                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=input_scale,
+                )
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale, requires_grad=False)
+            else:
+                weight_scale = layer.weight_scale.data
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        else:
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
+            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+        else:
+            layer.input_scale = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        maybe_create_device_identity()
+
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        # TODO: update create_xxx_parameter functions to return
+        # the newly added parameters
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+        # min requirement for fp8 kernels
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            input_scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        return self.fp8_linear.apply(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+        )
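In the per-tensor branch above, `requantize_with_max_scale` collapses the separate scales of a fused module (e.g. the q/k/v shards of `qkv_proj`) into a single scale so the kernel can run one per-tensor FP8 GEMM. The following is a self-contained sketch of that idea in plain PyTorch, not the sglang helper itself; the shard widths and scales are made up, and it assumes a PyTorch build with `torch.float8_e4m3fn`:

```python
import torch

torch.manual_seed(0)

# A fused weight made of three logical shards, each stored with its own FP8 scale.
logical_widths = [64, 64, 128]                   # e.g. q/k/v output widths
shard_scales = torch.tensor([0.02, 0.05, 0.03])  # per-shard weight scales
fused_fp8 = torch.randn(sum(logical_widths), 32).to(torch.float8_e4m3fn)

# Requantize every shard with the shared maximum scale.
max_w_scale = shard_scales.max()
dequant = torch.empty(sum(logical_widths), 32, dtype=torch.float32)
start = 0
for width, scale in zip(logical_widths, shard_scales):
    # Dequantize the shard with its own scale ...
    dequant[start : start + width] = fused_fp8[start : start + width].to(torch.float32) * scale
    start += width
# ... then quantize the whole fused weight with the single max scale.
requantized = (dequant / max_w_scale).to(torch.float8_e4m3fn)

print(max_w_scale.item(), requantized.dtype)  # one scale for the whole fused tensor
```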
sglang/srt/layers/quantization/compressed_tensors/utils.py (new file)

@@ -0,0 +1,218 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from types import MappingProxyType
+from typing import Iterable, List, Mapping, Optional
+
+from compressed_tensors import CompressionFormat
+from torch.nn import Module
+
+
+def is_activation_quantization_format(format: str) -> bool:
+    _ACTIVATION_QUANTIZATION_FORMATS = [
+        CompressionFormat.naive_quantized.value,
+        CompressionFormat.int_quantized.value,
+        CompressionFormat.float_quantized.value,
+    ]
+    return format in _ACTIVATION_QUANTIZATION_FORMATS
+
+
+def should_ignore_layer(
+    layer_name: Optional[str],
+    ignore: Iterable[str] = tuple(),
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    if layer_name is None:
+        return False
+
+    # layer_name = model.layers.0.self_attn.qkv_proj
+    # proj_name = qkv_proj
+    proj_name = layer_name.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping and layer_name not in ignore:
+        shard_proj_names = fused_mapping[proj_name]
+
+        # Convert fused_name --> [shard_names]
+        shard_names = [
+            layer_name.replace(proj_name, shard_proj_name)
+            for shard_proj_name in shard_proj_names
+        ]
+
+        # Layer should be ignored if shards are ignored.
+        should_ignore_layer = None
+        for shard_name in shard_names:
+            should_ignore_shard = check_equal_or_regex_match(
+                layer_name=shard_name, targets=ignore
+            )
+
+            # If shard_idx=0, set layer ignore to match shard.
+            if should_ignore_layer is None:
+                should_ignore_layer = should_ignore_shard
+
+            # If shard_idx=1+ confirm scheme matches prior shards.
+            elif should_ignore_shard != should_ignore_layer:
+                raise ValueError(
+                    f"Found a different quantization schemes for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+
+    # Unfused layers like down_proj and o_proj will match
+    # the safetensors checkpoint already.
+    else:
+        should_ignore_layer = check_equal_or_regex_match(
+            layer_name=layer_name, targets=ignore
+        )
+
+    assert should_ignore_layer is not None
+    return should_ignore_layer
+
+
+def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
+    """
+    Checks whether a layer_name is exactly equal or a regex match for
+    if target starts with 're:' to any target in list.
+    """
+    for target in targets:
+        if _is_equal_or_regex_match(layer_name, target):
+            return True
+    return False
+
+
+def find_matched_target(
+    layer_name: Optional[str],
+    module: Module,
+    targets: Iterable[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> str:
+    """
+    Helper function to look up which "target" in the compressed-tensors
+    config that a layer corresponds to.
+
+    Recall that a compressed-tensors configs has a concept of
+    config_groups, where each layer can be quantized with with a different
+    scheme.
+
+    targets in each config_group will be a list of either layer names
+    (or regexes corresponding to layer names) or names of torch Modules.
+
+    First, we try to match the layer_name with a target
+    Second, we try to match the module's name with a target
+    Third, we try to map the layer_name to a list of fused module names.
+    *All* component module names must match in order for a match to be
+    successful. A successful match returns the first component target
+
+    :param layer_name: layer name
+    :param module: torch.nn.Module
+    :param targets: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+    :param fused_strategy: either "all" or "any". If using "all", fused
+        layers match if "all" of its components match
+    """
+
+    if layer_name is None:
+        layer_name = ""
+
+    matched_target = (
+        _find_first_match(layer_name, targets)
+        or _find_first_match(module.__class__.__name__, targets, True)
+        or _match_fused_layer(layer_name, targets, fused_mapping)
+    )
+
+    if matched_target is None:
+        raise ValueError(
+            f"Unable to find matching target for {layer_name} in the "
+            "compressed-tensors config."
+        )
+
+    return matched_target
+
+
+def _find_first_match(
+    value: str, targets: Iterable[str], check_contains: bool = False
+) -> Optional[str]:
+    """
+    Returns first element of target that matches value either
+    exactly or as a regex after 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+
+    :param value: string to compare the list of targets against
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    for target in targets:
+        if _is_equal_or_regex_match(value, target, check_contains=check_contains):
+            return target
+    return None
+
+
+def _is_equal_or_regex_match(
+    value: str, target: str, check_contains: bool = False
+) -> bool:
+    """
+    Checks whether a value is exactly equal or a regex match for target
+    if target starts with 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+    """
+
+    if target.startswith("re:"):
+        pattern = target[3:]
+        if re.match(pattern, value):
+            return True
+    elif check_contains:
+        if target.lower() in value.lower():
+            return True
+    elif target == value:
+        return True
+    return False
+
+
+def _match_fused_layer(
+    layer_name: str,
+    target_layers: Iterable[str],
+    fused_mapping: Mapping[str, List[str]],
+) -> Optional[str]:
+    """
+    Match a fused layer name to its corresponding individual layer in
+    target_layers. Returns first value in fused_mapping which matches targets
+
+    Implements an "all" matching strategy where a fused layer matches iff
+    "all" of its components match
+
+    :param layer_name: layer name
+    :param target_layers: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+
+    Examples:
+        layer_name = "model.layers.0.self_attn.qkv_proj"
+        target_layers = ["model.layers.0.self_attn.q_proj",
+                         "model.layers.0.self_attn.k_proj",
+                         "model.layers.0.self_attn.v_proj"]
+    """
+    # find layer_name in mapping
+    fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
+    if fused is None:
+        return None
+
+    # expand path of unfused components
+    unfused_paths = [
+        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
+    ]
+
+    # for each unfused component, find a match in targets
+    unfused_matches: List[Optional[str]] = []
+    for unfused in unfused_paths:
+        for target in target_layers:
+            if _is_equal_or_regex_match(unfused, target):
+                unfused_matches.append(target)
+                break
+        else:
+            unfused_matches.append(None)
+
+    return unfused_matches[0] if all(unfused_matches) else None
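The helpers above encode the compressed-tensors matching convention: a target prefixed with `re:` is treated as a regular expression (checked with `re.match`), anything else must match exactly, and fused modules such as `qkv_proj` are expanded through `fused_mapping` and checked shard by shard. A small usage sketch, assuming sglang 0.4.4.post2 is installed so the new module path resolves (the layer names, ignore list, and `fused_mapping` below are made up for illustration):

```python
from sglang.srt.layers.quantization.compressed_tensors.utils import (
    check_equal_or_regex_match,
    should_ignore_layer,
)

ignore = ["re:.*lm_head", "model.layers.0.mlp.down_proj"]

# Exact-string target.
print(check_equal_or_regex_match("model.layers.0.mlp.down_proj", ignore))  # True
# "re:" target, evaluated with re.match against the layer name.
print(check_equal_or_regex_match("model.lm_head", ignore))                 # True
print(check_equal_or_regex_match("model.layers.1.mlp.down_proj", ignore))  # False

# Fused layers are expanded into their shards; all shards must agree.
fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(
    should_ignore_layer(
        "model.layers.0.self_attn.qkv_proj",
        ignore=["re:.*self_attn\\.(q|k|v)_proj"],
        fused_mapping=fused_mapping,
    )
)  # True: q_proj, k_proj and v_proj all match the regex
```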
sglang/srt/layers/quantization/fp8.py

@@ -7,20 +7,33 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
-
-from
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
-    prepare_fp8_layer_for_marlin,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     convert_to_channelwise,
+    is_layer_skipped,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )

+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     LinearBase,
@@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import (
     get_bool_env_var,
+    is_cuda,
     is_hip,
     permute_weight,
     print_warning_once,
@@ -60,6 +74,13 @@ if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight

+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
 logger = logging.getLogger(__name__)


@@ -131,8 +152,6 @@ class Fp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
@@ -141,8 +160,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, Attention):
-            return Fp8KVCacheMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -173,7 +190,9 @@ class Fp8LinearMethod(LinearMethodBase):

         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin =
+        self.use_marlin = (
+            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
+        )
         # Disable marlin for ROCm
         if _is_hip:
             self.use_marlin = False
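After this change the Marlin FP8 path has to be enabled explicitly and survives a missing vLLM install: `use_marlin` is only true when the `SGLANG_FORCE_FP8_MARLIN` environment variable is truthy and `MARLIN_FP8_AVAILABLE` is true (i.e. vLLM's `marlin_utils_fp8` imported successfully); the try/except blocks in the hunks below additionally fall back at call time. A minimal sketch of the gate, with a simplified stand-in for `get_bool_env_var` (the real helper lives in `sglang.srt.utils`):

```python
import os

os.environ["SGLANG_FORCE_FP8_MARLIN"] = "1"  # opt in explicitly


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).lower() in ("1", "true")


MARLIN_FP8_AVAILABLE = True  # False when vLLM's marlin_utils_fp8 cannot be imported
_is_hip = False              # Marlin stays disabled on ROCm regardless

use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
if _is_hip:
    use_marlin = False
print(use_marlin)  # True only when both conditions hold on a CUDA build
```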
@@ -371,9 +390,12 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.use_marlin:
-
-
-
+            try:
+                prepare_fp8_layer_for_marlin(layer)
+                # Activations not quantized for marlin.
+                del layer.input_scale
+            except ImportError:
+                self.use_marlin = False

     def apply(
         self,
@@ -383,15 +405,18 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:

         if self.use_marlin:
-
-
-
-
-
-
-
-
-
+            try:
+                return apply_fp8_marlin_linear(
+                    input=x,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    workspace=layer.workspace,
+                    size_n=layer.output_size_per_partition,
+                    size_k=layer.input_size_per_partition,
+                    bias=bias,
+                )
+            except ImportError:
+                self.use_marlin = False

         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -680,12 +705,20 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-
-
-
-
-
-
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

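The per-expert loop above now calls `sgl_scaled_fp8_quant` (from sgl-kernel) on CUDA and vLLM's `scaled_fp8_quant` otherwise; both return the FP8 tensor together with its scale. As a reference for what that per-tensor, dynamic-scale quantization computes, here is a plain-PyTorch sketch (illustrative only, not the optimized kernel; assumes a PyTorch build with `torch.float8_e4m3fn`):

```python
from typing import Optional, Tuple

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def scaled_fp8_quant_reference(
    w: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Dynamic per-tensor scale when none is supplied: amax / fp8_max.
    if scale is None:
        scale = w.abs().amax().float() / FP8_MAX
    q = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale


w = torch.randn(256, 128, dtype=torch.bfloat16)  # one expert's weight, made-up shape
q, s = scaled_fp8_quant_reference(w)
print(q.dtype, q.shape, float(s))
```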
@@ -762,9 +795,18 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-
-
-
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
                     start += shard_size

                 layer.w13_weight_scale = torch.nn.Parameter(
sglang/srt/layers/quantization/fp8_kernel.py

@@ -26,11 +26,14 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
+    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
 )

+_enable_jit_deepgemm = False
+
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

@@ -39,9 +42,12 @@ if _is_cuda:
     import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

-
+    sm_version = get_device_sm()
+    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+        _enable_jit_deepgemm = True

-
+
+
 logger = logging.getLogger(__name__)

 if supports_custom_op():
@@ -168,6 +174,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.

@@ -200,11 +207,20 @@
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
-
-
-
-
-
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
     else:
         x_s = torch.empty(
             x.shape[:-1] + (x.shape[-1] // group_size,),
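The new `scale_tma_aligned` flag only changes how the column-major scale tensor is laid out: the token dimension of the backing storage is rounded up to a multiple of 4 (4 * sizeof(float) = 16 bytes) so the column stride is TMA-friendly, and the view is then sliced back to the real token count. A short sketch of the resulting shapes and strides with made-up sizes:

```python
import torch

x = torch.empty(7, 512)   # 7 tokens, hidden size 512 (illustrative)
group_size = 128          # -> 512 // 128 = 4 groups per token

# Plain column-major scales: allocate [groups, tokens] and view it as [tokens, groups].
x_s_plain = torch.empty(
    (x.shape[-1] // group_size,) + x.shape[:-1], dtype=torch.float32
).permute(-1, -2)

# TMA-aligned variant: pad the token dimension of the storage up to a multiple of 4.
aligned_size = (x.shape[-2] + 3) // 4 * 4  # 7 -> 8
x_s_aligned = torch.empty(
    x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), dtype=torch.float32
).permute(-1, -2)[: x.shape[-2], :]

print(x_s_plain.shape, x_s_plain.stride())      # torch.Size([7, 4]) (1, 7)
print(x_s_aligned.shape, x_s_aligned.stride())  # torch.Size([7, 4]) (1, 8) -> padded column stride
```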
@@ -761,7 +777,7 @@
     )

     # deepgemm only support bf16
-    if
+    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else: