sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Tuple

 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits

-        # Apply the custom logit processors if registered in the sampling info
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -80,9 +90,9 @@ class Sampler(nn.Module):
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

         else:
-            #
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)

             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -123,9 +133,10 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
             if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(
-                    min=torch.finfo(
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
             else:
                 logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

@@ -164,6 +175,54 @@ class Sampler(nn.Module):

         return batch_next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -233,10 +292,95 @@ def get_top_logprobs(
     )


-def get_token_ids_logprobs(
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+            [] for _ in token_ids_logprobs
+        ]
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor [6] into chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # Example: Convert split tensors back to list format with proper empty handling
+    # i=0: [1,3] -> append split_logprobs[0] and [1,3]
+    # i=1: [2] -> append split_logprobs[1] and [2]
+    # i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
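Note (illustration, not part of the package diff): a minimal standalone sketch of the vectorized gather pattern that `get_token_ids_logprobs_batch_optimized` introduces above, reusing the example values from its docstring; only `torch` is assumed.

import torch

# Example from the docstring above: batch_size=3, vocab_size=5.
logprobs = torch.tensor([
    [-1.2, -2.1, -0.8, -3.0, -1.5],
    [-0.5, -1.8, -2.2, -1.1, -2.7],
    [-2.0, -0.9, -1.4, -2.8, -1.6],
])
token_ids = [[1, 3], [2], [0, 2, 4]]

# Flatten all (request, token_id) pairs so a single advanced-indexing call
# replaces one gather per request.
lengths = torch.tensor([len(ids) for ids in token_ids])
rows = torch.repeat_interleave(torch.arange(len(token_ids)), lengths)  # [0, 0, 1, 2, 2, 2]
cols = torch.tensor([t for ids in token_ids for t in ids])             # [1, 3, 2, 0, 2, 4]

gathered = logprobs[rows, cols]                                  # one kernel for the whole batch
per_request = torch.split_with_sizes(gathered, lengths.tolist())
# per_request ~ [[-2.1, -3.0], [-2.2], [-2.0, -1.4, -1.6]] (up to float32 rounding)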
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.

     Args:
-
-
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """

-    def __init__(self,
-        self.
-        self.
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass

-    def
-        self
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass


 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
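Note (illustration, not part of the package diff): the two new methods are hooks meant to be overridden per backend, as the Triton implementation in the next file does. A rough sketch of a custom subclass, with a hypothetical class name:

from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class MyLoRABackend(BaseLoRABackend):  # hypothetical backend, for illustration only
    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
    ):
        # One-time setup of whatever stays constant across captured batches
        # (the Triton backend, for example, fills segment lengths with 1 here).
        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        # If batch_info is given, update the preallocated CUDA-graph buffers in place;
        # otherwise build a fresh LoRABatchInfo for this forward pass.
        ...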
sglang/srt/lora/backend/triton_backend.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch

 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
     sgemm_lora_b_fwd,
 )
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class TritonLoRABackend(BaseLoRABackend):
+    name = "triton"

-    def __init__(self,
-        super().__init__(
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        super().__init__(max_loras_per_batch, device)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
             base_output,
         )
         return lora_output
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+        self.batch_info = batch_info
sglang/srt/lora/layers.py
CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def set_lora_info(
         self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         self.B_buffer_gate_up = B_buffer

+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+                2 * shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_output = self.lora_backend.run_gate_up_lora(
             x=x,
             gate_up_lora_a=self.A_buffer_gate_up,
             gate_up_lora_b=self.B_buffer_gate_up,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
+        output_size = self.base_layer.output_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                output_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
sglang/srt/lora/lora.py
CHANGED
@@ -28,6 +28,9 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader

@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
            gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
            if up_name not in weights:
                weights[up_name] = torch.zeros_like(weights[weight_name])
-           assert self.lora_backend
+           assert isinstance(self.lora_backend, TritonLoRABackend), (
                f"LoRA weight initialization currently only supported for 'triton' backend. "
                f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                f"or consider implementing custom initialization logic for other backends."