sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -69,7 +69,10 @@ class LoRAManager:
|
|
69
69
|
# LoRA backend for running sgemm kernels
|
70
70
|
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
71
71
|
backend_type = get_backend_from_name(lora_backend)
|
72
|
-
self.lora_backend: BaseLoRABackend = backend_type(
|
72
|
+
self.lora_backend: BaseLoRABackend = backend_type(
|
73
|
+
max_loras_per_batch=max_loras_per_batch,
|
74
|
+
device=self.device,
|
75
|
+
)
|
73
76
|
|
74
77
|
# Initialize mutable internal state of the LoRAManager.
|
75
78
|
self.init_state(
|
@@ -82,29 +85,22 @@ class LoRAManager:
|
|
82
85
|
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
|
83
86
|
with torch.device("cuda"):
|
84
87
|
self.cuda_graph_batch_info = LoRABatchInfo(
|
85
|
-
bs=
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
),
|
88
|
+
bs=max_bs_in_cuda_graph,
|
89
|
+
use_cuda_graph=True,
|
90
|
+
num_segments=None,
|
91
|
+
seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
92
|
+
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
|
90
93
|
max_len=1,
|
91
|
-
weight_indices=torch.zeros(
|
92
|
-
|
93
|
-
),
|
94
|
+
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
95
|
+
permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
|
94
96
|
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
|
95
97
|
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
|
96
98
|
)
|
97
99
|
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
|
103
|
-
dim=0,
|
104
|
-
out=self.cuda_graph_batch_info.seg_indptr[
|
105
|
-
1 : self.max_bs_in_cuda_graph + 1
|
106
|
-
],
|
107
|
-
)
|
100
|
+
self.lora_backend.init_cuda_graph_batch_info(
|
101
|
+
cuda_graph_batch_info=self.cuda_graph_batch_info,
|
102
|
+
max_bs_in_cuda_graph=max_bs_in_cuda_graph,
|
103
|
+
)
|
108
104
|
|
109
105
|
def create_lora_update_result(
|
110
106
|
self, success: bool, error_message: str = ""
|
@@ -232,7 +228,6 @@ class LoRAManager:
|
|
232
228
|
return required_slots <= mem_pool_vacancy
|
233
229
|
|
234
230
|
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
235
|
-
|
236
231
|
# Load active loras into lora memory pool
|
237
232
|
cur_uids = set(forward_batch.lora_ids)
|
238
233
|
|
@@ -247,102 +242,30 @@ class LoRAManager:
|
|
247
242
|
# set up batch info shared by all lora modules
|
248
243
|
bs = forward_batch.batch_size
|
249
244
|
|
250
|
-
|
251
|
-
weight_indices_out: torch.Tensor,
|
252
|
-
lora_ranks_out: torch.Tensor,
|
253
|
-
scalings_out: torch.Tensor,
|
254
|
-
):
|
255
|
-
"""
|
256
|
-
Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
|
257
|
-
to device (CUDA) asynchronously.
|
258
|
-
"""
|
259
|
-
weight_indices = [0] * len(forward_batch.lora_ids)
|
260
|
-
lora_ranks = [0] * self.max_loras_per_batch
|
261
|
-
scalings = [0] * self.max_loras_per_batch
|
262
|
-
for i, uid in enumerate(forward_batch.lora_ids):
|
263
|
-
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
264
|
-
if uid is not None:
|
265
|
-
lora = self.loras[uid]
|
266
|
-
lora_ranks[weight_indices[i]] = lora.config.r
|
267
|
-
scalings[weight_indices[i]] = lora.scaling
|
268
|
-
|
269
|
-
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
270
|
-
weight_indices_tensor = torch.tensor(
|
271
|
-
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
|
272
|
-
)
|
273
|
-
lora_ranks_tensor = torch.tensor(
|
274
|
-
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
|
275
|
-
)
|
276
|
-
scalings_tensor = torch.tensor(
|
277
|
-
scalings, dtype=torch.float, pin_memory=True, device="cpu"
|
278
|
-
)
|
279
|
-
|
280
|
-
# Copy to device tensors asynchronously
|
281
|
-
weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
|
282
|
-
lora_ranks_out[: self.max_loras_per_batch].copy_(
|
283
|
-
lora_ranks_tensor, non_blocking=True
|
284
|
-
)
|
285
|
-
scalings_out[: self.max_loras_per_batch].copy_(
|
286
|
-
scalings_tensor, non_blocking=True
|
287
|
-
)
|
288
|
-
|
289
|
-
if (
|
245
|
+
use_cuda_graph = (
|
290
246
|
hasattr(self, "max_bs_in_cuda_graph")
|
291
247
|
and bs <= self.max_bs_in_cuda_graph
|
292
248
|
and forward_batch.forward_mode.is_cuda_graph()
|
293
|
-
)
|
294
|
-
# Do in-place updates when CUDA graph is enabled and the batch forward mode
|
295
|
-
# could use CUDA graph.
|
296
|
-
|
297
|
-
transfer_adapter_info(
|
298
|
-
self.cuda_graph_batch_info.weight_indices,
|
299
|
-
self.cuda_graph_batch_info.lora_ranks,
|
300
|
-
self.cuda_graph_batch_info.scalings,
|
301
|
-
)
|
302
|
-
|
303
|
-
self.cuda_graph_batch_info.bs = bs
|
304
|
-
self.cuda_graph_batch_info.max_len = 1
|
305
|
-
batch_info = self.cuda_graph_batch_info
|
306
|
-
else:
|
307
|
-
weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
308
|
-
lora_ranks = torch.zeros(
|
309
|
-
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
|
310
|
-
)
|
311
|
-
scalings = torch.zeros(
|
312
|
-
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
|
313
|
-
)
|
314
|
-
transfer_adapter_info(
|
315
|
-
weight_indices,
|
316
|
-
lora_ranks,
|
317
|
-
scalings,
|
318
|
-
)
|
319
|
-
|
320
|
-
seg_lens = (
|
321
|
-
forward_batch.extend_seq_lens
|
322
|
-
if forward_batch.forward_mode.is_extend()
|
323
|
-
else torch.ones(bs, device=self.device)
|
324
|
-
)
|
325
|
-
|
326
|
-
max_len = (
|
327
|
-
# Calculate max_len from the CPU copy to avoid D2H transfer.
|
328
|
-
max(forward_batch.extend_seq_lens_cpu)
|
329
|
-
if forward_batch.forward_mode.is_extend()
|
330
|
-
else 1
|
331
|
-
)
|
249
|
+
)
|
332
250
|
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
weight_indices=
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
251
|
+
weight_indices = [0] * len(forward_batch.lora_ids)
|
252
|
+
lora_ranks = [0] * self.max_loras_per_batch
|
253
|
+
scalings = [0] * self.max_loras_per_batch
|
254
|
+
for i, uid in enumerate(forward_batch.lora_ids):
|
255
|
+
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
256
|
+
if uid is not None:
|
257
|
+
lora = self.loras[uid]
|
258
|
+
lora_ranks[weight_indices[i]] = lora.config.r
|
259
|
+
scalings[weight_indices[i]] = lora.scaling
|
260
|
+
# Do in-place updates when CUDA graph is enabled and the batch forward mode
|
261
|
+
# could use CUDA graph.
|
262
|
+
self.lora_backend.prepare_lora_batch(
|
263
|
+
forward_batch=forward_batch,
|
264
|
+
weight_indices=weight_indices,
|
265
|
+
lora_ranks=lora_ranks,
|
266
|
+
scalings=scalings,
|
267
|
+
batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
|
268
|
+
)
|
346
269
|
|
347
270
|
def update_lora_info(self):
|
348
271
|
"""
|
sglang/srt/lora/mem_pool.py
CHANGED
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
|
|
104
104
|
return all(_can_support(x) for x in config)
|
105
105
|
|
106
106
|
def get_lora_A_shape(
|
107
|
-
self,
|
107
|
+
self,
|
108
|
+
module_name: str,
|
109
|
+
base_model: torch.nn.Module,
|
110
|
+
max_lora_dim: int,
|
111
|
+
layer_idx: int,
|
108
112
|
) -> Tuple[int]:
|
109
113
|
"""
|
110
114
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
111
115
|
"""
|
112
|
-
input_dim, _ = get_hidden_dim(
|
116
|
+
input_dim, _ = get_hidden_dim(
|
117
|
+
module_name, self.base_hf_config, base_model, layer_idx
|
118
|
+
)
|
113
119
|
c = get_stacked_multiply(module_name)
|
114
120
|
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
115
121
|
input_dim = divide(input_dim, self.tp_size)
|
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
|
|
120
126
|
)
|
121
127
|
|
122
128
|
def get_lora_B_shape(
|
123
|
-
self,
|
129
|
+
self,
|
130
|
+
module_name: str,
|
131
|
+
base_model: torch.nn.Module,
|
132
|
+
max_lora_dim: int,
|
133
|
+
layer_idx: int,
|
124
134
|
) -> Tuple[int]:
|
125
135
|
"""
|
126
136
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
127
137
|
"""
|
128
|
-
_, output_dim = get_hidden_dim(
|
138
|
+
_, output_dim = get_hidden_dim(
|
139
|
+
module_name, self.base_hf_config, base_model, layer_idx
|
140
|
+
)
|
129
141
|
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
|
130
142
|
output_dim = divide(output_dim, self.tp_size)
|
131
143
|
return (
|
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
|
|
140
152
|
def init_buffer(
|
141
153
|
buffer: Dict[str, List[torch.Tensor]],
|
142
154
|
target_modules: Set[str],
|
143
|
-
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
155
|
+
get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
|
144
156
|
):
|
145
157
|
for module_name in target_modules:
|
146
|
-
lora_shape = get_lora_shape_fn(
|
147
|
-
module_name, base_model, self.max_lora_rank
|
148
|
-
)
|
149
158
|
buffer[module_name] = [
|
150
159
|
torch.empty(
|
151
|
-
|
160
|
+
get_lora_shape_fn(
|
161
|
+
module_name,
|
162
|
+
base_model,
|
163
|
+
self.max_lora_rank,
|
164
|
+
idx,
|
165
|
+
),
|
152
166
|
dtype=self.dtype,
|
153
167
|
device=device,
|
154
168
|
)
|
155
|
-
for
|
169
|
+
for idx in range(self.num_layer)
|
156
170
|
]
|
157
171
|
|
158
172
|
init_buffer(
|
sglang/srt/lora/utils.py
CHANGED
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
|
|
10
10
|
|
11
11
|
@dataclass
|
12
12
|
class LoRABatchInfo:
|
13
|
+
# The forward mode is using CUDA Graph.
|
14
|
+
use_cuda_graph: bool
|
15
|
+
|
13
16
|
# Batch size
|
14
17
|
bs: int
|
15
18
|
|
16
|
-
#
|
17
|
-
|
19
|
+
# Number of segments. For triton backend, it is equal to batch size.
|
20
|
+
num_segments: int
|
18
21
|
|
19
|
-
# Indice pointers of each
|
22
|
+
# Indice pointers of each segment in shape (num_segments + 1, )
|
20
23
|
seg_indptr: torch.Tensor
|
21
24
|
|
22
|
-
#
|
23
|
-
max_len: int
|
24
|
-
|
25
|
-
# The index of lora adapter used by each sequence, in shape (bs,)
|
25
|
+
# The index of lora adapter used by each segment, in shape (num_segments,)
|
26
26
|
weight_indices: torch.Tensor
|
27
27
|
|
28
28
|
# ranks of each lora adapter, in shape (lora_num,)
|
@@ -31,6 +31,15 @@ class LoRABatchInfo:
|
|
31
31
|
# scaling of each lora adapter, in shape (lora_num,)
|
32
32
|
scalings: torch.Tensor
|
33
33
|
|
34
|
+
# Lengths of each segments in shape (num_segments,)
|
35
|
+
seg_lens: Optional[torch.Tensor]
|
36
|
+
|
37
|
+
# Maximum segment length of current batch
|
38
|
+
max_len: Optional[int]
|
39
|
+
|
40
|
+
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
|
41
|
+
permutation: Optional[torch.Tensor]
|
42
|
+
|
34
43
|
|
35
44
|
class LoRAType(Enum):
|
36
45
|
LORA_A = 0
|
@@ -48,14 +57,14 @@ def get_layer_id(name: str) -> int:
|
|
48
57
|
|
49
58
|
|
50
59
|
def get_hidden_dim(
|
51
|
-
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
60
|
+
module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
|
52
61
|
) -> Tuple[int]:
|
53
62
|
"""
|
54
63
|
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
55
64
|
"""
|
56
65
|
|
57
66
|
if hasattr(base_model, "get_hidden_dim"):
|
58
|
-
return base_model.get_hidden_dim(module_name)
|
67
|
+
return base_model.get_hidden_dim(module_name, layer_idx)
|
59
68
|
else:
|
60
69
|
"""
|
61
70
|
WARNING: get_hidden_dim() is not defined,
|
@@ -0,0 +1,170 @@
|
|
1
|
+
"""
|
2
|
+
Asynchronous dynamic batch tokenizer for SGLang.
|
3
|
+
|
4
|
+
This module provides an async tokenizer with dynamic batching capabilities
|
5
|
+
to reduce tokenization overhead when multiple requests arrive concurrently.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import asyncio
|
9
|
+
import logging
|
10
|
+
from concurrent.futures import ThreadPoolExecutor
|
11
|
+
from functools import partial
|
12
|
+
from typing import Any, Dict, List, Optional
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)


class AsyncDynamicbatchTokenizer:
    """Asynchronous tokenizer with dynamic batching for single string prompts.

    Dynamically batches pending encode requests from a queue to reduce overhead.
    Only handles single string prompts - regular batch processing of multiple
    strings per request should be handled at a higher level.
    A single-thread ThreadPoolExecutor is used so the event loop stays responsive.

    Failures are isolated per request: if the tokenizer raises for one prompt,
    only that caller receives the exception; other requests that happened to
    share its batch still get their results.

    Note: Uses lazy initialization for asyncio components because this class
    is instantiated in TokenizerManager.__init__() before the event loop starts.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        """
        Args:
            tokenizer: Callable used as ``tokenizer(text, **kwargs)`` for one prompt
                and ``tokenizer(list_of_texts, **kwargs)`` for a batch. The batched
                result must provide ``.items()`` whose values are indexable per prompt.
            max_batch_size: Upper bound on how many requests are merged into one batch.
            batch_wait_timeout_s: How long to keep collecting requests once a batch
                has been started because other requests were already queued.
        """
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s

        # Single queue for all encode requests - initialized lazily
        self._queue: Optional[asyncio.Queue] = None
        self._batcher_task: Optional[asyncio.Task] = None

        # Single-thread executor for blocking tokenizer calls
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Lazily create the queue and the batcher task on the running event loop."""
        if not self._initialized:
            self._queue = asyncio.Queue()
            self._batcher_task = asyncio.create_task(self._dynamic_batch_loop())
            self._initialized = True

    async def __call__(self, prompt: str, **kwargs) -> Any:
        """Encode a single prompt (alias for ``encode``)."""
        return await self.encode(prompt, **kwargs)

    async def encode(self, prompt: str, **kwargs) -> Any:
        """Queue a single prompt for tokenization and await its result.

        Raises whatever the underlying tokenizer raised for this prompt.
        """
        self._ensure_initialized()
        result_future: asyncio.Future = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt, kwargs, result_future))
        return await result_future

    async def _dynamic_batch_loop(self) -> None:
        """Dynamically batch incoming encode requests for efficiency.

        Runs forever as a background task. Every request taken off the queue has
        its future resolved (result, exception, or cancellation), so callers never
        hang because of an internal error or a shutdown.
        """
        while True:
            prompts: List[str] = []
            kwargs_list: List[Dict] = []
            result_futures: List[asyncio.Future] = []
            try:
                # Block until at least one request is available.
                prompt, kwargs, result_future = await self._queue.get()
                prompts.append(prompt)
                kwargs_list.append(kwargs)
                result_futures.append(result_future)

                # A lone request is processed immediately without added latency;
                # only when others are already waiting do we spend up to
                # batch_wait_timeout_s collecting a larger batch.
                if not self._queue.empty():
                    await self._collect_more_requests(
                        prompts, kwargs_list, result_futures
                    )

                logger.debug(
                    "AsyncDynamicbatchTokenizer: Processing dynamic batch of size %d",
                    len(prompts),
                )
                await self._process_dynamic_batch(prompts, kwargs_list, result_futures)
            except asyncio.CancelledError:
                # Shutting down: do not leave already-dequeued callers waiting forever.
                for fut in result_futures:
                    if not fut.done():
                        fut.cancel()
                raise
            except Exception as e:
                logger.exception("Error in dynamic batch loop: %s", e)
                # Fail requests already taken off the queue so their callers do not
                # hang, then continue the loop to handle other requests.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _collect_more_requests(
        self,
        prompts: List[str],
        kwargs_list: List[Dict],
        result_futures: List[asyncio.Future],
    ) -> None:
        """Append queued requests in place until the batch is full or the wait budget is spent."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.batch_wait_timeout_s
        while len(prompts) < self.max_batch_size:
            remaining_time = deadline - loop.time()
            if remaining_time <= 0:
                break
            try:
                prompt, kwargs, result_future = await asyncio.wait_for(
                    self._queue.get(), remaining_time
                )
            except asyncio.TimeoutError:
                break
            prompts.append(prompt)
            kwargs_list.append(kwargs)
            result_futures.append(result_future)

    def _encode_individually(
        self, prompts: List[str], kwargs_list: List[Dict]
    ) -> List[tuple]:
        """Tokenize each prompt with its own kwargs, capturing per-prompt errors.

        Intended to run on the executor thread. Returns one ``(result, error)``
        pair per prompt; ``error`` is ``None`` on success.
        """
        outcomes: List[tuple] = []
        for prompt, kwargs in zip(prompts, kwargs_list):
            try:
                outcomes.append((self.tokenizer(prompt, **kwargs), None))
            except Exception as e:
                outcomes.append((None, e))
        return outcomes

    async def _process_dynamic_batch(
        self,
        prompts: List[str],
        kwargs_list: List[Dict],
        result_futures: List[asyncio.Future],
    ) -> None:
        """Process a dynamic batch of encode requests for single string prompts."""
        # Compare kwargs via their sorted string form so that unhashable values
        # (lists, dicts) can still be compared.
        can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1
        loop = asyncio.get_running_loop()

        try:
            if can_batch and len(prompts) > 1:
                # Identical kwargs: a single batched tokenizer call is much faster.
                try:
                    results = await loop.run_in_executor(
                        self._executor,
                        partial(self.tokenizer, prompts, **kwargs_list[0]),
                    )
                except Exception as e:
                    # One bad prompt must not fail its neighbours: fall through to
                    # per-request encoding, which isolates the failure.
                    logger.warning(
                        "AsyncDynamicbatchTokenizer: batched tokenizer call failed "
                        "(%s); retrying %d requests individually.",
                        e,
                        len(prompts),
                    )
                else:
                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            fut.set_result({k: v[i] for k, v in results.items()})
                    return
            elif len(prompts) > 1:
                # Process each request individually due to different kwargs
                logger.warning(
                    f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} "
                    f"requests due to differing kwargs. This reduces performance benefits. "
                    f"Consider using consistent tokenization parameters across requests."
                )

            outcomes = await loop.run_in_executor(
                self._executor, self._encode_individually, prompts, kwargs_list
            )
            for fut, (result, error) in zip(result_futures, outcomes):
                if fut.done():
                    # e.g. the caller cancelled its request meanwhile.
                    continue
                if error is None:
                    fut.set_result(result)
                else:
                    fut.set_exception(error)
        except Exception as e:
            logger.exception("Error in dynamic batch processing: %s", e)
            for fut in result_futures:
                if not fut.done():
                    fut.set_exception(e)

    def __del__(self):
        """Clean up background tasks."""
        if hasattr(self, "_batcher_task") and self._batcher_task:
            if not self._batcher_task.done():
                self._batcher_task.cancel()
        if hasattr(self, "_executor"):
            self._executor.shutdown(wait=False)
|