sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
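For reference, a diff like this can be reproduced locally from the two published wheels; a minimal sketch using pip and the Python standard library (paths and the external `diff` tool are illustrative, not part of the registry output):

    import subprocess
    import zipfile
    from pathlib import Path

    for version in ("0.4.3.post1", "0.4.3.post3"):
        dest = Path(f"/tmp/sglang-{version}")
        dest.mkdir(parents=True, exist_ok=True)
        # Download the published wheel without resolving dependencies.
        subprocess.run(
            ["pip", "download", f"sglang=={version}", "--no-deps", "-d", str(dest)],
            check=True,
        )
        wheel = next(dest.glob("sglang-*.whl"))
        zipfile.ZipFile(wheel).extractall(dest / "src")

    # Recursive diff of the two extracted trees (same file list as below).
    subprocess.run(["diff", "-ru", "/tmp/sglang-0.4.3.post1/src", "/tmp/sglang-0.4.3.post3/src"])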
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
         )
         self.embedding_dim = embedding_dim

-
+        quant_method = None
         if quant_config is not None:
-
-
-
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+            print("quant_method", quant_method)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()

         # If we are making an embedding layer, then our quantization linear
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
-
-            type(
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method)
         )
-        if is_embedding_layer and not
+        if is_embedding_layer and not quant_method_implements_embedding:
             raise NotImplementedError(
-                f"The class {type(
+                f"The class {type(quant_method).__name__} must implement "
                 "the 'embedding' method, see UnquantizedEmbeddingMethod."
             )

-        self.
+        self.quant_method: QuantizeMethodBase = quant_method

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             - self.shard_indices.added_vocab_start_index
         )

-        self.
+        self.quant_method.create_weights(
             self,
             self.embedding_dim,
             [self.num_embeddings_per_partition],
@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             packed_factor = (
                 param.packed_factor
                 if isinstance(param, BasevLLMParameter)
-                else param.
+                else param.packed_factor
             )
             assert loaded_weight.shape[output_dim] == (
                 self.org_vocab_size // param.packed_factor
@@ -457,7 +458,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             assert loaded_weight.shape[output_dim] == (
                 self.org_vocab_size
                 // (self.tp_size if self.use_presharded_weights else 1)
-            )
+            ), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"

             # Copy the data.
             if not self.use_presharded_weights:
@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.
+        output_parallel = self.quant_method.embedding(self, masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
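The hunks above resolve a quantization method with an unquantized fallback, then verify that the resolved class really overrides the base `embedding` method. A minimal standalone sketch of that override check (class names here are hypothetical; the real `method_has_implemented_embedding` may be implemented differently):

    class BaseQuantMethod:
        def embedding(self, layer, input_):
            raise NotImplementedError

    class UnquantizedMethod(BaseQuantMethod):
        def embedding(self, layer, input_):
            return input_  # stand-in for the real embedding lookup

    class WeightOnlyMethod(BaseQuantMethod):
        pass  # provides linear kernels but never overrides embedding()

    def has_implemented_embedding(method_class) -> bool:
        # True only if the class replaced the base method with its own.
        return method_class.embedding is not BaseQuantMethod.embedding

    print(has_implemented_embedding(UnquantizedMethod))  # True
    print(has_implemented_embedding(WeightOnlyMethod))   # False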
sglang/srt/lora/lora.py
CHANGED
@@ -18,6 +18,7 @@
 # LoRA layers class inheritance adapted from:
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

+import logging
 import re
 from typing import Dict, List

@@ -30,6 +31,8 @@ from sglang.srt.lora.backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader

+logger = logging.getLogger(__name__)
+

 class LoRALayer(nn.Module):
     def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -173,6 +176,18 @@ class LoRAAdapter(nn.Module):
             if "gate_proj" in weight_name:
                 up_name = weight_name.replace("gate_proj", "up_proj")
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if up_name not in weights:
+                    logger.warning(
+                        f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
+                        f"Initializing up projection to zero."
+                    )
+                    weights[up_name] = torch.zeros_like(weights[weight_name])
+                    # FIXME: Add gate-only support for flashinfer in future implementations
+                    assert self.lora_backend.name == "triton", (
+                        f"LoRA weight initialization currently only supported for 'triton' backend. "
+                        f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
+                        f"or consider implementing custom initialization logic for other backends."
+                    )
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = torch.cat(
                         (weights[weight_name], weights[up_name]), 0
@@ -182,4 +197,5 @@ class LoRAAdapter(nn.Module):
                         [weights[weight_name], weights[up_name]], dim=0
                     )
                 weights.pop(weight_name)
-                weights
+                if up_name in weights:
+                    weights.pop(up_name)
sglang/srt/lora/lora_config.py
CHANGED
@@ -26,6 +26,11 @@ class LoRAConfig:
         self.path = path
         self.hf_config = self.get_lora_config()
         self.target_modules = self.hf_config["target_modules"]
+
+        # TODO: Support more modules
+        if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]):
+            raise ValueError("Not supported yet")
+
         self.r = self.hf_config["r"]
         self.lora_alpha = self.hf_config["lora_alpha"]

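For illustration, a sketch of what the new guard accepts and rejects (the module lists below are hypothetical examples, not taken from any real adapter):

    supported = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    unsupported = ["q_proj", "v_proj", "embed_tokens"]

    for target_modules in (supported, unsupported):
        if any(module in target_modules for module in ["embed_tokens", "lm_head"]):
            print(target_modules, "-> rejected (embedding/lm_head LoRA not supported yet)")
        else:
            print(target_modules, "-> accepted")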
sglang/srt/lora/lora_manager.py
CHANGED
@@ -76,9 +76,7 @@ class LoRAManager:
         self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
-            self.hf_target_names
-                self.configs[name].target_modules
-            )
+            self.hf_target_names.update(self.configs[name].target_modules)

         # Target lora weight names for lora_a and lora_b modules repectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
sglang/srt/managers/cache_controller.py
CHANGED
@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
     http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,14 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import concurrent.futures
 import logging
+import math
 import threading
-from queue import PriorityQueue, Queue
-from typing import Optional
+from queue import Empty, Full, PriorityQueue, Queue
+from typing import List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool,
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost

 logger = logging.getLogger(__name__)

@@ -55,6 +55,27 @@ class CacheOperation:
         self.priority = min(self.priority, other.priority)
         self.node_ids.extend(other.node_ids)

+    def split(self, factor) -> List["CacheOperation"]:
+        # split an operation into smaller operations to reduce the size of intermediate buffers
+        if factor <= 1:
+            return [self]
+
+        chunk_size = math.ceil(len(self.host_indices) / factor)
+        split_ops = []
+        for i in range(0, len(self.host_indices), chunk_size):
+            split_ops.append(
+                CacheOperation(
+                    host_indices=self.host_indices[i : i + chunk_size],
+                    device_indices=self.device_indices[i : i + chunk_size],
+                    node_id=0,
+                )
+            )
+        # Inherit the node_ids on the final chunk
+        if split_ops:
+            split_ops[-1].node_ids = self.node_ids
+
+        return split_ops
+
     def __lt__(self, other: "CacheOperation"):
         return self.priority < other.priority

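A standalone sketch of the chunking arithmetic in `CacheOperation.split` (the index tensor is hypothetical); note that only the last chunk carries the node ids, so a node is acknowledged only after its final piece finishes transferring:

    import math
    import torch

    host_indices = torch.arange(10)
    factor = 3
    chunk_size = math.ceil(len(host_indices) / factor)  # ceil(10 / 3) = 4
    chunks = [
        host_indices[i : i + chunk_size]
        for i in range(0, len(host_indices), chunk_size)
    ]
    print([len(c) for c in chunks])  # [4, 4, 2] -- at most `factor` chunks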
@@ -64,7 +85,10 @@ class TransferBuffer:
     Overlapping buffer preparation and transfer operations to improve throughput.
     """

-    def __init__(
+    def __init__(
+        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
+    ) -> None:
+        self.stop_event = stop_event
         self.buffers = Queue(maxsize=buffer_count)
         # todo: adjust the buffer size based on throughput profile of the system
         self.max_buffer_size = max_buffer_size
@@ -75,22 +99,36 @@ class TransferBuffer:
     def empty(self) -> bool:
         return self.buffers.empty()

-    def put(self, item, block=True) -> None:
-        self.
+    def put(self, item, block=True, timeout=1) -> None:
+        while not self.stop_event.is_set():
+            try:
+                self.buffers.put(item, block=block, timeout=timeout)
+                break
+            except Full:
+                if not block:
+                    break
+                continue
+            except Exception as e:
+                logger.error(e)

-    def get(self, block=True) -> Optional[CacheOperation]:
+    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
         try:
-            return self.buffers.get(block=block)
+            return self.buffers.get(block=block, timeout=timeout)
+        except Empty:
+            return None
         except Exception as e:
             logger.error(e)

+    def clear(self):
+        self.buffers.queue.clear()
+

 class HiCacheController:

     def __init__(
         self,
         mem_pool_device: MHATokenToKVPool,
-        mem_pool_host:
+        mem_pool_host: MHATokenToKVPoolHost,
         write_policy: str = "write_through_selective",
     ):

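The timeout-plus-stop-event pattern above is what makes the worker threads joinable; a minimal standalone sketch of the same consumer loop (names hypothetical):

    import threading
    from queue import Empty, Queue

    stop_event = threading.Event()
    work_queue: "Queue[int]" = Queue(maxsize=3)

    def consumer():
        # Wake up at least once per second so stop_event.set() can end the
        # loop even if the queue stays empty (a bare blocking get() would hang).
        while not stop_event.is_set():
            try:
                item = work_queue.get(block=True, timeout=1)
            except Empty:
                continue
            print("processed", item)

    t = threading.Thread(target=consumer, daemon=True)
    t.start()
    work_queue.put(42)
    stop_event.set()
    t.join()  # returns promptly thanks to the timeout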
@@ -111,8 +149,11 @@
         self.ack_write_queue = Queue()
         self.ack_load_queue = Queue()

-        self.
-        self.
+        self.stop_event = threading.Event()
+        self.write_buffer = TransferBuffer(self.stop_event)
+        self.load_buffer = TransferBuffer(
+            self.stop_event, buffer_count=10, max_buffer_size=100
+        )

         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
@@ -126,6 +167,28 @@
         self.write_thread.start()
         self.load_thread.start()

+    def reset(self):
+        self.stop_event.set()
+        self.write_thread.join()
+        self.load_thread.join()
+
+        self.write_queue.queue.clear()
+        self.load_queue.queue.clear()
+        self.write_buffer.clear()
+        self.load_buffer.clear()
+        self.ack_write_queue.queue.clear()
+        self.ack_load_queue.queue.clear()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.stop_event.clear()
+        self.write_thread.start()
+        self.load_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -138,10 +201,10 @@
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
+        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
-        self.mem_pool_host.protect_write(host_indices)
         return host_indices

     def load(
@@ -156,10 +219,10 @@
         device_indices = self.mem_pool_device.alloc(len(host_indices))
         if device_indices is None:
             return None
+        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
-        self.mem_pool_host.protect_load(host_indices)
         return device_indices

     def write_thread_func_direct(self):
@@ -167,16 +230,19 @@
         Directly write through KV caches to host memory without buffering.
         """
         with torch.cuda.stream(self.write_stream):
-            while
+            while not self.stop_event.is_set():
                 try:
-                    operation = self.write_queue.get(block=True)
+                    operation = self.write_queue.get(block=True, timeout=1)
                     operation.data = self.mem_pool_device.get_flat_data(
                         operation.device_indices
                     )
                     self.mem_pool_host.transfer(operation.host_indices, operation.data)
                     self.mem_pool_host.complete_io(operation.host_indices)
                     for node_id in operation.node_ids:
-
+                        if node_id != 0:
+                            self.ack_write_queue.put(node_id)
+                except Empty:
+                    continue
                 except Exception as e:
                     logger.error(e)

@@ -185,9 +251,10 @@
         Directly load KV caches from host memory to device memory without buffering.
         """
         with torch.cuda.stream(self.load_stream):
-            while
+            while not self.stop_event.is_set():
                 try:
-                    operation = self.load_queue.get(block=True)
+                    operation = self.load_queue.get(block=True, timeout=1)
+                    # time.sleep(18e-6 * len(operation.host_indices))
                     operation.data = self.mem_pool_host.get_flat_data(
                         operation.host_indices
                     )
@@ -196,7 +263,10 @@
                     )
                     self.mem_pool_host.complete_io(operation.host_indices)
                     for node_id in operation.node_ids:
-
+                        if node_id != 0:
+                            self.ack_load_queue.put(node_id)
+                except Empty:
+                    continue
                 except Exception as e:
                     logger.error(e)

@@ -204,39 +274,98 @@
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+
+        def _to_op(op_):
+            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
+            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
+                self.mem_pool_host.device
+            )
+            self.write_buffer.put(op_)
+            return op_
+
         buffer = None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with torch.cuda.stream(self.write_stream):
+            while not self.stop_event.is_set():
+                try:
+                    operation = self.write_queue.get(block=True, timeout=1)
+                    factor = (
+                        len(operation.device_indices)
+                        // self.write_buffer.max_buffer_size
+                    )
+
+                    if factor >= 1:
+                        if buffer is not None:
+                            _to_op(buffer)
+                            buffer = None
+
+                        if factor < 2:
+                            _to_op(operation)
+                        else:
+                            split_ops = operation.split(factor)
+                            for op_ in split_ops:
+                                _to_op(op_)
+                        continue
+
+                    if buffer is None:
+                        buffer = operation
+                    else:
+                        buffer.merge(operation)
+                    if (
+                        no_wait
+                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                        or self.write_queue.empty()
+                        or self.write_buffer.empty()
+                    ):
+                        _to_op(buffer)
+                        buffer = None
+                except Empty:
+                    continue
+                except Exception as e:
+                    logger.error(e)

     def load_aux_func(self):
         """
         Auxiliary function to prepare the buffer for load operations.
         """
+
+        def _pin_op(op_, put=True):
+            op_.data = (
+                self.mem_pool_host.get_flat_data(op_.host_indices)
+                .contiguous()
+                .pin_memory()
+            )
+            if put:
+                self.load_buffer.put(op_)
+            return op_
+
         buffer = None
-        while
+        while not self.stop_event.is_set():
             try:
-                operation = self.load_queue.get(block=True)
+                operation = self.load_queue.get(block=True, timeout=1)
+                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
+
+                if factor >= 1:
+                    if buffer is not None:
+                        _pin_op(buffer)
+                        buffer = None
+
+                    if factor < 2:
+                        _pin_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        split_args = [(op_, True) for op_ in split_ops[:-1]]
+                        split_args.append((split_ops[-1], False))
+                        # Spawn threads to pin each op concurrently
+                        with concurrent.futures.ThreadPoolExecutor() as executor:
+                            pinned_ops = list(
+                                executor.map(
+                                    lambda x: _pin_op(x[0], put=x[1]), split_args
+                                )
+                            )
+                        # preserve the order of last op to ensure correct ack
+                        self.load_buffer.put(pinned_ops[-1])
+                    continue
+
                 if buffer is None:
                     buffer = operation
                 else:
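A standalone sketch of the concurrent `pin_memory()` staging used in `load_aux_func` (chunk shapes hypothetical; requires a CUDA-capable torch build). Pinning page-locks host memory so the later host-to-device copy can run asynchronously:

    import concurrent.futures
    import torch

    # Hypothetical host-resident chunks standing in for KV-cache slices.
    chunks = [torch.randn(1024, 256) for _ in range(4)]

    def pin(t: torch.Tensor) -> torch.Tensor:
        # contiguous() then pin_memory() stages the chunk in page-locked RAM.
        return t.contiguous().pin_memory()

    with concurrent.futures.ThreadPoolExecutor() as executor:
        pinned = list(executor.map(pin, chunks))

    print(all(t.is_pinned() for t in pinned))  # True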
@@ -246,41 +375,43 @@ class HiCacheController:
                     or self.load_queue.empty()
                     or self.load_buffer.empty()
                 ):
-                    buffer
-                        self.mem_pool_host.get_flat_data(buffer.host_indices)
-                        .contiguous()
-                        .pin_memory()
-                    )
-                    self.load_buffer.put(buffer, block=True)
+                    _pin_op(buffer)
                     buffer = None
+            except Empty:
+                continue
             except Exception as e:
                 logger.error(e)

     def write_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
         aux_thread.start()
-
-
-
-
-
-
-
-
+
+        while not self.stop_event.is_set():
+            operation = self.write_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
                     self.ack_write_queue.put(node_id)
+        aux_thread.join()

     def load_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
+
         with torch.cuda.stream(self.load_stream):
-            while
+            while not self.stop_event.is_set():
                 operation = self.load_buffer.get()
                 if operation is None:
                     continue
                 self.mem_pool_device.transfer(operation.device_indices, operation.data)
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
-
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+        aux_thread.join()

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
sglang/srt/managers/configure_logging.py
CHANGED
@@ -28,6 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
    parser.add_argument("--log-requests", action="store_true")
+    parser.add_argument("--log-requests-level", type=int, default=2)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
@@ -38,7 +39,7 @@ if __name__ == "__main__":
         args.url + "/configure_logging",
         json={
             "log_requests": args.log_requests,
-            "log_requests_level":
+            "log_requests_level": args.log_requests_level,  # Log full requests
             "dump_requests_folder": args.dump_requests_folder,
             "dump_requests_threshold": args.dump_requests_threshold,
         },
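The hunks above show the endpoint and payload keys; a sketch of the equivalent request, assuming the script posts its settings with `requests` (the import is outside this hunk) and using example values only:

    import requests

    requests.post(
        "http://localhost:30000/configure_logging",
        json={
            "log_requests": True,
            "log_requests_level": 2,  # now configurable; 2 logs full requests per the comment above
            "dump_requests_folder": "/tmp/sglang_request_dump",
            "dump_requests_threshold": 1000,  # hypothetical threshold value
        },
    )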
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -121,7 +121,7 @@ class DataParallelController:
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += server_args.tp_size
+            base_gpu_id += server_args.tp_size * server_args.gpu_id_step

         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -177,7 +177,11 @@ class DataParallelController:
             rank_port_args.nccl_port = port_args.nccl_port

             reader, writer = mp.Pipe(duplex=False)
-            gpu_id =
+            gpu_id = (
+                server_args.base_gpu_id
+                + base_gpu_id
+                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+            )
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
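A worked example of the new GPU id arithmetic (values hypothetical): with `gpu_id_step=2`, tensor-parallel ranks land on every other device, and each data-parallel replica advances the base by `tp_size * gpu_id_step`:

    base_gpu_id = 0      # server_args.base_gpu_id
    tp_size = 4
    tp_size_per_node = 4
    gpu_id_step = 2      # e.g. to leave a gap between used devices

    dp_offset = 0
    for dp_rank in range(2):
        for tp_rank in range(tp_size):
            gpu_id = base_gpu_id + dp_offset + (tp_rank % tp_size_per_node) * gpu_id_step
            print(f"dp={dp_rank} tp={tp_rank} -> GPU {gpu_id}")
        dp_offset += tp_size * gpu_id_step  # the controller's per-replica increment
    # dp=0 -> GPUs 0, 2, 4, 6; dp=1 -> GPUs 8, 10, 12, 14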
|