sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
2
|
+
|
3
|
+
from types import MappingProxyType
|
4
|
+
from typing import List, Mapping, Tuple, Union
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from sglang.srt.utils import is_cuda
|
9
|
+
|
10
|
+
_is_cuda = is_cuda()
|
11
|
+
|
12
|
+
if _is_cuda:
|
13
|
+
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
14
|
+
else:
|
15
|
+
from vllm import _custom_ops as vllm_ops
|
16
|
+
|
17
|
+
|
18
|
+
def is_fp8_fnuz() -> bool:
    """Return True when device 0 is an AMD gfx94x GPU (FP8 e4m3fnuz format).

    Only device 0 is inspected; this assumes MI300-class platforms are
    homogeneous.
    """
    props = torch.cuda.get_device_properties(0)
    # ``gcnArchName`` may be absent from the properties object on non-ROCm
    # builds of PyTorch; treat a missing attribute as "not fnuz" instead of
    # raising AttributeError.
    return "gfx94" in getattr(props, "gcnArchName", "")
|
21
|
+
|
22
|
+
|
23
|
+
def is_layer_skipped(
    prefix: str,
    ignored_layers: List[str],
    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    """Return whether the layer named ``prefix`` is excluded from quantization.

    Args:
        prefix: Fully qualified module name, e.g.
            ``model.layers.0.self_attn.q_proj``.
        ignored_layers: Module names that must stay unquantized.
        fused_mapping: Maps a fused projection name (e.g. ``qkv_proj``) to the
            names of its unfused shards (e.g. ``["q_proj", "k_proj", "v_proj"]``).

    Returns:
        For an unfused layer, whether ``prefix`` is in ``ignored_layers``.
        For a fused layer, whether all of its shards are in ``ignored_layers``.

    Raises:
        ValueError: If only some shards of a fused layer are ignored.
    """
    # "model.layers.0.self_attn.q_proj" -> parent "model.layers.0.self_attn",
    # proj_name "q_proj". ``sep`` is "" when ``prefix`` contains no dot.
    parent, sep, proj_name = prefix.rpartition(".")

    shard_names = fused_mapping.get(proj_name)
    if not shard_names:
        # Not a fused layer (or no shards to reconcile): look it up directly.
        return prefix in ignored_layers

    # Fused layers like gate_up_proj or qkv_proj are not fused in the
    # safetensors checkpoint, so check each unfused shard name. Only the
    # trailing component is substituted, so an identical substring earlier
    # in the path is left untouched.
    shard_prefixes = [parent + sep + shard_name for shard_name in shard_names]

    # Every shard of a fused layer must share the same scheme.
    is_skipped = None
    for shard_prefix in shard_prefixes:
        is_shard_skipped = shard_prefix in ignored_layers
        if is_skipped is None:
            is_skipped = is_shard_skipped
        elif is_shard_skipped != is_skipped:
            raise ValueError(
                f"Detected some but not all shards of {prefix} "
                "are quantized. All shards of fused layers "
                "are required to have the same precision."
            )

    return is_skipped
|
59
|
+
|
60
|
+
|
61
|
+
def per_tensor_dequantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    """Dequantize ``tensor`` using a single (per-tensor) scale.

    The input is upcast to float16 and then multiplied by ``inv_scale``.
    """
    return tensor.to(torch.float16) * inv_scale
|
67
|
+
|
68
|
+
|
69
|
+
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True if every element of the 1-D tensor ``x`` is close to ``x[0]``.

    Closeness follows ``torch.allclose`` with ``x[0]`` as input and each
    element as ``other``. An empty tensor returns True.
    """
    assert len(x.shape) == 1
    if x.numel() == 0:
        return True
    # One vectorized comparison instead of one allclose call per element.
    return torch.allclose(x[0].expand_as(x), x)
|
72
|
+
|
73
|
+
|
74
|
+
def convert_to_channelwise(
    weight_scale: torch.Tensor, logical_widths: List[int]
) -> torch.Tensor:
    """Expand per-shard scales into a per-output-channel scale column.

    Args:
        weight_scale: Either a 0-dim tensor (one scale for every channel) or
            an indexable tensor whose entry ``i`` is the scale of shard ``i``.
        logical_widths: Number of output channels in each shard.

    Returns:
        A float32 tensor of shape ``(sum(logical_widths), 1)`` on
        ``weight_scale``'s device, where rows of shard ``i`` hold
        ``weight_scale[i]``.
    """
    weight_scale_channel = torch.empty(
        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
    )

    # Scalar case: broadcast the same scale to all channels.
    if weight_scale.dim() == 0:
        weight_scale_channel.fill_(weight_scale.item())
        return weight_scale_channel

    # Expand each shard's scale over that shard's rows.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
|
95
|
+
|
96
|
+
|
97
|
+
def requantize_with_max_scale(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Re-quantize all logical shards of ``weight`` onto one shared scale.

    ``weight`` stacks the shards along dim 0; shard ``i`` spans
    ``logical_widths[i]`` rows and was quantized with ``weight_scale[i]``.
    Each shard is dequantized with its own scale and re-quantized with the
    maximum scale, writing the result back into ``weight`` in place.

    Returns:
        ``(max_w_scale, weight)``: the shared maximum scale and the weight
        tensor (modified in place when re-quantization happens).
    """
    # The single scale every shard will share after re-quantization.
    shared_scale = weight_scale.max()

    # If QKV / MLP is fused in the on-disk checkpoint, N scales are allocated
    # for N shards but only one is loaded, so the trailing scale keeps its
    # default value. The comparison below treats
    # torch.finfo(torch.float8_e4m3fn).min as that default (NOTE(review):
    # confirm against where weight_scale is initialized). In that case the
    # weight is already quantized with a single scale and nothing is redone.
    # * Sample model: nm-testing/Phi-3-mini-128k-instruct-FP8
    last_scale_was_loaded = weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
    if not last_scale_was_loaded:
        return shared_scale, weight

    # Unfused checkpoint: requantize every shard with the shared scale.
    row = 0
    for shard, width in enumerate(logical_widths):
        rows = slice(row, row + width)
        dequantized = per_tensor_dequantize(weight[rows, :], weight_scale[shard])
        if _is_cuda:
            weight[rows, :], _ = sgl_scaled_fp8_quant(dequantized, shared_scale)
        else:
            weight[rows, :], _ = vllm_ops.scaled_fp8_quant(
                dequantized, shared_scale
            )
        row += width

    return shared_scale, weight
|
128
|
+
|
129
|
+
|
130
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
    """Make ``dst`` adopt the shape, stride and contents of ``src`` in place.

    ``dst`` keeps its object identity, so a module keeps referencing the same
    parameter object. The caller must ensure ``dst``'s storage is large enough
    for ``src``'s layout (``replace_parameter`` compares storage byte sizes).

    Raises:
        ValueError: If ``dst`` and ``src`` have different dtypes.
    """
    if dst.dtype != src.dtype:
        raise ValueError(
            f"Tensors must have the same dtype, got {dst.dtype} and {src.dtype}"
        )
    # ``dst`` may be a leaf that requires grad; in-place edits on such a leaf
    # are rejected while autograd is recording.
    with torch.no_grad():
        # Update the view metadata (shape/stride) to match ``src``.
        dst.as_strided_(src.shape, src.stride())
        # Only copy data when the two tensors do not already alias.
        if dst.data_ptr() != src.data_ptr():
            dst.copy_(src)


# Newly generated tensors need to replace existing tensors that are already
# registered as module parameters (and won't be freed otherwise).
def replace_parameter(
    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
) -> None:
    """Replace parameter ``name`` of ``mod`` with the values of ``new``.

    If ``new`` has the same type and dtype as the existing parameter and the
    same underlying storage byte size, the existing parameter is updated in
    place (no re-registration). Otherwise ``new`` is registered as a fresh,
    non-trainable ``torch.nn.Parameter`` under ``name``.
    """
    old = getattr(mod, name)
    if (
        type(old) is type(new)
        and old.dtype == new.dtype
        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
    ):
        # Update in place to avoid re-registering the parameter.
        update_tensor_inplace(old, new)
    else:
        # Fallback: re-register, converting to Parameter if necessary. This
        # ensures we never register a plain tensor as a parameter, and that
        # parameter subclasses are re-registered as parameters for
        # `torch.compile` compatibility.
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new, requires_grad=False)
        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
|
|
9
9
|
QuantizationConfig,
|
10
10
|
QuantizeMethodBase,
|
11
11
|
)
|
12
|
+
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
12
13
|
from sglang.srt.layers.quantization.fp8_utils import (
|
13
14
|
apply_fp8_linear,
|
14
15
|
cutlass_fp8_supported,
|
16
|
+
input_to_float8,
|
15
17
|
normalize_e4m3fn_to_e4m3fnuz,
|
16
18
|
)
|
17
19
|
from sglang.srt.utils import is_hip
|
@@ -22,12 +24,24 @@ _is_hip = is_hip()
|
|
22
24
|
class W8A8Fp8Config(QuantizationConfig):
|
23
25
|
"""Config class for W8A8 FP8 Quantization.
|
24
26
|
|
25
|
-
|
26
|
-
-
|
27
|
+
Weight Quantization:
|
28
|
+
- Method: Static quantization
|
29
|
+
- Granularity: Per-channel
|
30
|
+
- Type: Symmetric
|
31
|
+
|
32
|
+
Activation Quantization:
|
33
|
+
- Method: Dynamic quantization
|
34
|
+
- Granularity: Per-token
|
35
|
+
- Type: Symmetric
|
36
|
+
|
37
|
+
Note:
|
38
|
+
- For models without offline quantization, weights will be quantized during model loading
|
39
|
+
- If CUTLASS is supported: Per-channel weight quantization is used
|
40
|
+
- If CUTLASS is not supported: Falls back to per-tensor weight quantization
|
27
41
|
"""
|
28
42
|
|
29
|
-
def __init__(self):
|
30
|
-
|
43
|
+
def __init__(self, is_checkpoint_fp8_serialized: bool = False):
|
44
|
+
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
31
45
|
|
32
46
|
@classmethod
|
33
47
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
|
|
47
61
|
|
48
62
|
@classmethod
|
49
63
|
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
50
|
-
|
64
|
+
quant_method = cls.get_from_keys(config, ["quant_method"])
|
65
|
+
is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
|
66
|
+
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
|
51
67
|
|
52
68
|
def get_quant_method(
|
53
69
|
self,
|
@@ -72,13 +88,40 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|
72
88
|
|
73
89
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
74
90
|
weight = layer.weight
|
75
|
-
|
76
|
-
if
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
91
|
+
|
92
|
+
if self.quantization_config.is_checkpoint_fp8_serialized:
|
93
|
+
weight_scale = layer.weight_scale.detach()
|
94
|
+
# If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
|
95
|
+
if _is_hip:
|
96
|
+
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
97
|
+
weight=weight, weight_scale=weight_scale
|
98
|
+
)
|
99
|
+
|
100
|
+
layer.weight = Parameter(weight.t(), requires_grad=False)
|
101
|
+
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
102
|
+
else:
|
103
|
+
# If checkpoint not offline quantized, quantize the weights with per-channel quantization.
|
104
|
+
if self.cutlass_fp8_supported:
|
105
|
+
# if cutlass supported, we use cutlass_scaled_mm
|
106
|
+
# which requires per-channel quantization on weight
|
107
|
+
qweight, weight_scale = per_token_group_quant_fp8(
|
108
|
+
layer.weight, layer.weight.shape[-1]
|
109
|
+
)
|
110
|
+
weight_scale = weight_scale.t().contiguous()
|
111
|
+
if _is_hip:
|
112
|
+
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
113
|
+
weight=weight, weight_scale=weight_scale
|
114
|
+
)
|
115
|
+
else:
|
116
|
+
# if cutlass not supported, we fall back to use torch._scaled_mm
|
117
|
+
# which requires per tensor quantization on weight
|
118
|
+
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
119
|
+
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
|
120
|
+
|
121
|
+
# Update the layer with the new values.
|
122
|
+
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
123
|
+
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
124
|
+
layer.input_scale = None
|
82
125
|
|
83
126
|
def create_weights(
|
84
127
|
self,
|
@@ -90,6 +133,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|
90
133
|
params_dtype: torch.dtype,
|
91
134
|
**extra_weight_attrs
|
92
135
|
):
|
136
|
+
weight_dtype = (
|
137
|
+
torch.float8_e4m3fn
|
138
|
+
if self.quantization_config.is_checkpoint_fp8_serialized
|
139
|
+
else params_dtype
|
140
|
+
)
|
93
141
|
|
94
142
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
95
143
|
self.logical_widths = output_partition_sizes
|
@@ -98,7 +146,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|
98
146
|
data=torch.empty(
|
99
147
|
sum(output_partition_sizes),
|
100
148
|
input_size_per_partition,
|
101
|
-
dtype=
|
149
|
+
dtype=weight_dtype,
|
102
150
|
),
|
103
151
|
input_dim=1,
|
104
152
|
output_dim=0,
|
@@ -106,12 +154,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|
106
154
|
)
|
107
155
|
layer.register_parameter("weight", weight)
|
108
156
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
157
|
+
if self.quantization_config.is_checkpoint_fp8_serialized:
|
158
|
+
weight_scale = ChannelQuantScaleParameter(
|
159
|
+
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
160
|
+
output_dim=0,
|
161
|
+
weight_loader=weight_loader,
|
162
|
+
)
|
163
|
+
layer.register_parameter("weight_scale", weight_scale)
|
164
|
+
else:
|
165
|
+
layer.weight_scale = None
|
115
166
|
|
116
167
|
def apply(
|
117
168
|
self,
|
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
6
6
|
|
7
7
|
import torch
|
8
8
|
import torch.nn as nn
|
9
|
-
from vllm import _custom_ops as ops
|
10
9
|
|
11
10
|
from sglang.srt.custom_op import CustomOp
|
12
11
|
from sglang.srt.utils import is_cuda_available
|
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
|
|
14
13
|
_is_cuda_available = is_cuda_available()
|
15
14
|
if _is_cuda_available:
|
16
15
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
16
|
+
else:
|
17
|
+
from vllm import _custom_ops as ops
|
17
18
|
|
18
19
|
|
19
20
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
@@ -147,7 +148,7 @@ class RotaryEmbedding(CustomOp):
|
|
147
148
|
key: torch.Tensor,
|
148
149
|
offsets: Optional[torch.Tensor] = None,
|
149
150
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
150
|
-
if _is_cuda_available:
|
151
|
+
if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
|
151
152
|
apply_rope_with_cos_sin_cache_inplace(
|
152
153
|
positions=positions,
|
153
154
|
query=query,
|
@@ -168,76 +169,6 @@ class RotaryEmbedding(CustomOp):
|
|
168
169
|
)
|
169
170
|
return query, key
|
170
171
|
|
171
|
-
def forward_xpu(
|
172
|
-
self,
|
173
|
-
positions: torch.Tensor,
|
174
|
-
query: torch.Tensor,
|
175
|
-
key: torch.Tensor,
|
176
|
-
offsets: Optional[torch.Tensor] = None,
|
177
|
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
178
|
-
from vllm._ipex_ops import ipex_ops as ops
|
179
|
-
|
180
|
-
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
181
|
-
ops.rotary_embedding(
|
182
|
-
positions,
|
183
|
-
query,
|
184
|
-
key,
|
185
|
-
self.head_size,
|
186
|
-
self.cos_sin_cache,
|
187
|
-
self.is_neox_style,
|
188
|
-
)
|
189
|
-
return query, key
|
190
|
-
|
191
|
-
def forward_hpu(
|
192
|
-
self,
|
193
|
-
positions: torch.Tensor,
|
194
|
-
query: torch.Tensor,
|
195
|
-
key: torch.Tensor,
|
196
|
-
offsets: Optional[torch.Tensor] = None,
|
197
|
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
198
|
-
from habana_frameworks.torch.hpex.kernels import (
|
199
|
-
RotaryPosEmbeddingMode,
|
200
|
-
apply_rotary_pos_emb,
|
201
|
-
)
|
202
|
-
|
203
|
-
positions = positions.flatten()
|
204
|
-
if offsets is not None:
|
205
|
-
positions = positions + offsets
|
206
|
-
num_tokens = positions.shape[0]
|
207
|
-
cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
|
208
|
-
cos, sin = cos_sin.chunk(2, dim=-1)
|
209
|
-
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
210
|
-
# to query hidden dimension, so the original tensors need to be
|
211
|
-
# expanded
|
212
|
-
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
213
|
-
# and expansion of cos/sin tensors via concatenation
|
214
|
-
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
|
215
|
-
# and expansion of cos/sin tensors via repeat_interleave
|
216
|
-
rope_mode: RotaryPosEmbeddingMode
|
217
|
-
if self.is_neox_style:
|
218
|
-
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
219
|
-
cos = torch.cat((cos, cos), dim=-1)
|
220
|
-
sin = torch.cat((sin, sin), dim=-1)
|
221
|
-
else:
|
222
|
-
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
223
|
-
sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
|
224
|
-
cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
|
225
|
-
|
226
|
-
query_shape = query.shape
|
227
|
-
query = query.view(num_tokens, -1, self.head_size)
|
228
|
-
query_rot = query[..., : self.rotary_dim]
|
229
|
-
query_pass = query[..., self.rotary_dim :]
|
230
|
-
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
231
|
-
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
232
|
-
|
233
|
-
key_shape = key.shape
|
234
|
-
key = key.view(num_tokens, -1, self.head_size)
|
235
|
-
key_rot = key[..., : self.rotary_dim]
|
236
|
-
key_pass = key[..., self.rotary_dim :]
|
237
|
-
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
238
|
-
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
239
|
-
return query, key
|
240
|
-
|
241
172
|
def extra_repr(self) -> str:
|
242
173
|
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
243
174
|
s += f", max_position_embeddings={self.max_position_embeddings}"
|
@@ -510,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|
510
441
|
):
|
511
442
|
super().__init__()
|
512
443
|
|
513
|
-
if rotary_dim != head_size:
|
514
|
-
raise ValueError(
|
515
|
-
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
516
|
-
rotary_dim != head_size ({rotary_dim}!={head_size})."
|
517
|
-
)
|
518
444
|
if is_neox_style is False:
|
519
445
|
raise ValueError(
|
520
446
|
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
521
447
|
)
|
522
448
|
|
449
|
+
self.rotary_dim = rotary_dim
|
523
450
|
self.head_size = head_size
|
524
451
|
self.max_position_embeddings = max_position_embeddings
|
525
452
|
self.original_max_position_embeddings = original_max_position_embeddings
|
@@ -568,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|
568
495
|
* (
|
569
496
|
self.base
|
570
497
|
** (
|
571
|
-
torch.arange(0, self.
|
572
|
-
/ self.
|
498
|
+
torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
|
499
|
+
/ self.rotary_dim
|
573
500
|
)
|
574
501
|
)
|
575
502
|
)
|
@@ -618,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|
618
545
|
cos = cos.repeat(1, 2).unsqueeze(-2)
|
619
546
|
sin = sin.repeat(1, 2).unsqueeze(-2)
|
620
547
|
|
621
|
-
|
622
|
-
|
548
|
+
query_rot = query[..., : self.rotary_dim]
|
549
|
+
query_pass = query[..., self.rotary_dim :]
|
550
|
+
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
|
551
|
+
query = torch.cat((query_rot, query_pass), dim=-1)
|
552
|
+
|
553
|
+
key_rot = key[..., : self.rotary_dim]
|
554
|
+
key_pass = key[..., self.rotary_dim :]
|
555
|
+
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
|
556
|
+
key = torch.cat((key_rot, key_pass), dim=-1)
|
623
557
|
|
624
558
|
return query.flatten(-2), key.flatten(-2)
|
625
559
|
|
@@ -717,6 +651,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|
717
651
|
query: torch.Tensor,
|
718
652
|
key: torch.Tensor,
|
719
653
|
offsets: Optional[torch.Tensor] = None,
|
654
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
655
|
+
if _is_cuda_available:
|
656
|
+
return self.forward_cuda(positions, query, key, offsets)
|
657
|
+
else:
|
658
|
+
return self.forward_native(positions, query, key, offsets)
|
659
|
+
|
660
|
+
def forward_native(
|
661
|
+
self,
|
662
|
+
positions: torch.Tensor,
|
663
|
+
query: torch.Tensor,
|
664
|
+
key: torch.Tensor,
|
665
|
+
offsets: Optional[torch.Tensor] = None,
|
720
666
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
721
667
|
"""PyTorch-native implementation equivalent to forward()."""
|
722
668
|
query_rot = query[..., : self.rotary_dim]
|
@@ -879,8 +825,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|
879
825
|
spatial_merge_size: int,
|
880
826
|
context_len: int = 0,
|
881
827
|
seq_len: Optional[int] = None,
|
828
|
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
829
|
+
tokens_per_second: Optional[int] = None,
|
882
830
|
) -> Tuple[List[List[int]], int]:
|
883
|
-
"""
|
831
|
+
"""
|
832
|
+
Get mrope input positions and delta value.
|
833
|
+
|
834
|
+
:arg
|
835
|
+
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
836
|
+
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
837
|
+
|
838
|
+
"""
|
884
839
|
|
885
840
|
if isinstance(image_grid_thw, torch.Tensor):
|
886
841
|
image_grid_thw = image_grid_thw.tolist()
|
@@ -917,6 +872,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|
917
872
|
)
|
918
873
|
image_index += 1
|
919
874
|
remain_images -= 1
|
875
|
+
second_per_grid_t = 0
|
920
876
|
ed = ed_image
|
921
877
|
else:
|
922
878
|
t, h, w = (
|
@@ -924,6 +880,10 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|
924
880
|
video_grid_thw[video_index][1],
|
925
881
|
video_grid_thw[video_index][2],
|
926
882
|
)
|
883
|
+
if second_per_grid_ts is not None:
|
884
|
+
second_per_grid_t = second_per_grid_ts[video_index]
|
885
|
+
else:
|
886
|
+
second_per_grid_t = 1.0
|
927
887
|
video_index += 1
|
928
888
|
remain_videos -= 1
|
929
889
|
ed = ed_video
|
@@ -940,11 +900,11 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|
940
900
|
)
|
941
901
|
|
942
902
|
t_index = (
|
943
|
-
torch.arange(llm_grid_t)
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
903
|
+
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
|
904
|
+
* second_per_grid_t
|
905
|
+
* tokens_per_second
|
906
|
+
).flatten()
|
907
|
+
|
948
908
|
h_index = (
|
949
909
|
torch.arange(llm_grid_h)
|
950
910
|
.view(1, -1, 1)
|
@@ -1172,6 +1132,37 @@ def get_rope(
|
|
1172
1132
|
return rotary_emb
|
1173
1133
|
|
1174
1134
|
|
1135
|
+
# Copied from transformers
def rotate_half(x):
    """Rotate the last dimension by halves: ``[a, b] -> [-b, a]``.

    For an odd last-dimension size the first half is the shorter one.
    """
    mid = x.shape[-1] // 2
    front, back = x[..., :mid], x[..., mid:]
    return torch.cat((-back, front), dim=-1)
|
1141
|
+
|
1142
|
+
|
1143
|
+
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos`` and ``sin`` are unsqueezed at ``unsqueeze_dim`` so they can
    broadcast against ``q``/``k``. The rotation is computed in float32 and
    each result is cast back to the dtype of its corresponding input.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q32, k32 = q.float(), k.float()

    # The embedding is performed in float.
    cos32 = cos.unsqueeze(unsqueeze_dim).float()
    sin32 = sin.unsqueeze(unsqueeze_dim).float()

    rotated_q = (q32 * cos32) + (rotate_half(q32) * sin32)
    rotated_k = (k32 * cos32) + (rotate_half(k32) * sin32)

    return rotated_q.to(q_dtype), rotated_k.to(k_dtype)
|
1164
|
+
|
1165
|
+
|
1175
1166
|
def get_rope_cpu(
|
1176
1167
|
head_size: int,
|
1177
1168
|
rotary_dim: int,
|
sglang/srt/layers/sampler.py
CHANGED
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
|
|
168
168
|
group=self.tp_sync_group,
|
169
169
|
)
|
170
170
|
|
171
|
-
return batch_next_token_ids
|
171
|
+
return batch_next_token_ids
|
172
172
|
|
173
173
|
def _apply_custom_logit_processor(
|
174
174
|
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
|
@@ -5,7 +5,7 @@ import torch
|
|
5
5
|
from sglang.srt.lora.utils import LoRABatchInfo
|
6
6
|
|
7
7
|
|
8
|
-
def
|
8
|
+
def get_fuse_output_add_from_name(name: str) -> bool:
|
9
9
|
mapping = {
|
10
10
|
"triton": True,
|
11
11
|
"flashinfer": False,
|
@@ -28,14 +28,14 @@ class BaseLoRABackend:
|
|
28
28
|
Args:
|
29
29
|
name: name of backend
|
30
30
|
batch_info: information of current batch for use
|
31
|
-
|
32
|
-
and the operation of
|
31
|
+
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
|
32
|
+
and the operation of adding will be fused into kernel
|
33
33
|
"""
|
34
34
|
|
35
35
|
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
|
36
36
|
self.name = name
|
37
37
|
self.batch_info = batch_info
|
38
|
-
self.
|
38
|
+
self.fuse_output_add = get_fuse_output_add_from_name(name)
|
39
39
|
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
|
40
40
|
|
41
41
|
def run_lora_a_sgemm(
|
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
|
|
37
37
|
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
38
38
|
) -> torch.Tensor:
|
39
39
|
|
40
|
-
return
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
40
|
+
return (
|
41
|
+
self.segment_gemm.run(
|
42
|
+
x=x,
|
43
|
+
weights=weights,
|
44
|
+
batch_size=self.batch_info.bs,
|
45
|
+
weight_column_major=True,
|
46
|
+
seg_indptr=self.batch_info.seg_indptr,
|
47
|
+
weight_indices=self.batch_info.weight_indices,
|
48
|
+
)
|
49
|
+
* self.batch_info.scalings[0]
|
47
50
|
)
|
48
51
|
|
49
52
|
def run_qkv_lora(
|
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
|
|
90
93
|
weights=kv_lora_b[1],
|
91
94
|
)
|
92
95
|
|
93
|
-
return lora_output
|
96
|
+
return lora_output * self.batch_info.scalings[0]
|
94
97
|
|
95
98
|
def run_gate_up_lora(
|
96
99
|
self,
|
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
|
|
125
128
|
weights=gate_up_lora_b[1],
|
126
129
|
)
|
127
130
|
|
128
|
-
return lora_output
|
131
|
+
return lora_output * self.batch_info.scalings[0]
|
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
25
25
|
x: torch.Tensor,
|
26
26
|
weights: torch.Tensor,
|
27
27
|
base_output: torch.Tensor = None,
|
28
|
-
scaling: float = 1.0,
|
29
28
|
*args,
|
30
29
|
**kwargs
|
31
30
|
) -> torch.Tensor:
|
32
|
-
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output
|
31
|
+
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
|
33
32
|
|
34
33
|
def run_qkv_lora(
|
35
34
|
self,
|
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
39
38
|
output_offset: torch.Tensor,
|
40
39
|
max_qkv_out_dim: int,
|
41
40
|
base_output: torch.Tensor = None,
|
42
|
-
scaling: float = 1.0,
|
43
41
|
*args,
|
44
42
|
**kwargs
|
45
43
|
) -> torch.Tensor:
|
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
49
47
|
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
|
50
48
|
assert isinstance(qkv_lora_b, torch.Tensor)
|
51
49
|
|
52
|
-
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
|
50
|
+
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
|
53
51
|
lora_output = qkv_lora_b_fwd(
|
54
52
|
lora_a_output,
|
55
53
|
qkv_lora_b,
|
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
57
55
|
output_offset,
|
58
56
|
max_qkv_out_dim,
|
59
57
|
base_output,
|
60
|
-
scaling,
|
61
58
|
)
|
62
59
|
return lora_output
|
63
60
|
|
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
67
64
|
gate_up_lora_a: torch.Tensor,
|
68
65
|
gate_up_lora_b: torch.Tensor,
|
69
66
|
base_output: torch.Tensor = None,
|
70
|
-
scaling: float = 1.0,
|
71
67
|
*args,
|
72
68
|
**kwargs
|
73
69
|
) -> torch.Tensor:
|
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
79
75
|
output_dim = gate_up_lora_b.shape[-2] // 2
|
80
76
|
|
81
77
|
# lora_a_output: (s, 2 * r)
|
82
|
-
lora_a_output = sgemm_lora_a_fwd(
|
78
|
+
lora_a_output = sgemm_lora_a_fwd(
|
79
|
+
x, gate_up_lora_a, self.batch_info, stack_num=2
|
80
|
+
)
|
83
81
|
lora_output = gate_up_lora_b_fwd(
|
84
82
|
lora_a_output,
|
85
83
|
gate_up_lora_b,
|
86
84
|
self.batch_info,
|
87
85
|
output_dim,
|
88
86
|
base_output,
|
89
|
-
scaling,
|
90
87
|
)
|
91
88
|
return lora_output
|