sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
-
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
-
-
-
-
-
-
-
-
-# The condition is
-#
-
-
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
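The new `USE_ROWWISE_TORCH_SCALED_MM` flag is computed once at import time and only turns on for ROCm builds running on gfx94x-class (MI300) devices with PyTorch 2.7 or newer. A rough standalone equivalent of that check, using `torch.cuda.get_device_capability()` as a stand-in for sglang's `get_device_capability()` helper (an assumption, not the library code):

```python
import torch

def rowwise_scaled_mm_supported() -> bool:
    """Sketch of the gate above: ROCm gfx94x-class device plus torch >= 2.7."""
    try:
        version = tuple(map(int, torch.__version__.split("+")[0].split(".")[:3]))
    except ValueError:
        version = (0, 0, 0)
    if torch.version.hip is not None and torch.cuda.is_available():
        return torch.cuda.get_device_capability() >= (9, 4) and version >= (2, 7, 0)
    return False

print(rowwise_scaled_mm_supported())
```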
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype =
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-
-    if
-    dtype =
-
-
-
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
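`input_to_float8` now takes its saturation bound from `fp8_kernel`'s `fp8_dtype`/`fp8_max` on fnuz platforms rather than always asking `torch.finfo`. The underlying recipe is unchanged: scale by `fp_max / amax`, clamp, cast to fp8, and keep the reciprocal scale for dequantization. A minimal self-contained illustration of that recipe in plain PyTorch (not the sglang helper itself; needs a build that ships `float8_e4m3fn`, i.e. 2.1+):

```python
import torch

def quantize_fp8_tensorwise(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    """Illustrative tensor-wise fp8 quantization: returns (x_fp8, reciprocal scale)."""
    fp_max = torch.finfo(dtype).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp_max / amax
    x_q = (x.float() * scale).clamp(min=-fp_max, max=fp_max).to(dtype)
    return x_q, scale.reciprocal()

x = torch.randn(4, 8)
x_q, inv_scale = quantize_fp8_tensorwise(x)
x_dq = x_q.to(torch.float32) * inv_scale
print((x - x_dq).abs().max())  # small round-trip error
```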
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
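A quick shape check of the new `block_quant_dequant` helper, assuming it is exposed from `sglang.srt.layers.quantization.fp8_utils` (the file this hunk appears to modify): feed it an fp8 tensor plus one scale per (128, 128) block and get back a dequantized tensor of the requested dtype.

```python
import torch
from sglang.srt.layers.quantization.fp8_utils import block_quant_dequant  # assumed location

n, k, block = 256, 512, 128
x_q = torch.randn(n, k).to(torch.float8_e4m3fn)    # synthetic block-quantized weight
scales = torch.rand(n // block, k // block) + 0.5  # one scale per (128, 128) tile

w = block_quant_dequant(x_q, scales, [block, block], dtype=torch.bfloat16)
print(w.shape, w.dtype)  # torch.Size([256, 512]) torch.bfloat16
```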
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
@@ -76,7 +76,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +370,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
sglang/srt/layers/quantization/kv_cache.py
CHANGED
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
 
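`is_fp8_fnuz()` is now imported from `fp8_kernel` instead of being duplicated as a classmethod here, and it still gates the doubling of the loaded k/v scales. The doubling exists because `float8_e4m3fnuz` (the fp8 flavor on gfx94x / MI300 parts) covers roughly half the numeric range of the OCP `float8_e4m3fn` format that checkpoints are typically calibrated for, which the dtype limits show directly (recent PyTorch ships both dtypes):

```python
import torch

# e4m3fn saturates at 448, e4m3fnuz at 240, roughly half the range,
# hence k/v scales loaded from an e4m3fn checkpoint are multiplied by 2 on fnuz devices.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```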
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
sglang/srt/layers/quantization/w8a8_fp8.py
CHANGED
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import set_weight_attrs
 
-
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
             # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
sglang/srt/layers/sampler.py
CHANGED
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
 
 
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         self.enable_tp = enable_tp
         if self.enable_tp:
-
-
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1
 
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
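Both constructors now take everything after `embedding_dim` as keyword-only arguments (the bare `*`), so downstream code that passed `params_dtype` and friends positionally will raise a `TypeError` until it switches to keyword form, and the new `use_attn_tp_group` flag can only be set by name. A toy signature (not the sglang class) showing the breaking change:

```python
def new_style(num_embeddings, embedding_dim, *, params_dtype=None, prefix=""):
    return params_dtype, prefix

new_style(32000, 4096, params_dtype="bf16")   # ok: keyword form
try:
    new_style(32000, 4096, "bf16")            # positional params_dtype no longer allowed
except TypeError as err:
    print(err)  # new_style() takes 2 positional arguments but 3 were given
```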
sglang/srt/lora/lora_manager.py
CHANGED
@@ -100,7 +100,7 @@ class LoRAManager:
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target lora weight names for lora_a and lora_b modules
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -156,18 +156,15 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        if
-
-
-
-
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
             self.cuda_graph_batch_info.bs = bs
-
-            self.cuda_graph_batch_info.seg_lens[:bs].copy_(
-                forward_batch.extend_seq_lens
-            )
-            else:
-                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
             torch.cumsum(
                 self.cuda_graph_batch_info.seg_lens[:bs],
                 dim=0,
@@ -201,10 +198,10 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.
+        lora_ranks = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
         )
-        scalings = torch.
+        scalings = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
         )
         for i, lora_path in enumerate(forward_batch.lora_paths):
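`lora_ranks` and `scalings` are now allocated with `torch.zeros`, so slots that the following loop over `forward_batch.lora_paths` never fills read as rank 0 and scaling 0.0 instead of whatever the previous allocation left behind, which is presumably the point of the change. The difference in miniature:

```python
import torch

ranks_uninit = torch.empty((4,), dtype=torch.int64)  # contents are arbitrary leftover memory
ranks_zeroed = torch.zeros((4,), dtype=torch.int64)  # unused slots deterministically read 0
print(ranks_uninit)
print(ranks_zeroed)  # tensor([0, 0, 0, 0])
```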
sglang/srt/lora/mem_pool.py
CHANGED
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are
-        # Here we don't
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
sglang/srt/lora/triton_ops/gate_up_lora_b.py
CHANGED
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/utils.py
CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
     """
-    Given a module_name (might be a stacked name), return the hidden dims of modules'
+    Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):