sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -247,7 +247,7 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -458,7 +458,10 @@ class Scheduler(
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
             self.grammar_backend = create_grammar_backend(
-                server_args,
+                server_args,
+                self.tokenizer,
+                self.model_config.vocab_size,
+                self.model_config.hf_eos_token_id,
             )
         else:
             self.grammar_backend = None
@@ -653,6 +656,9 @@ class Scheduler(
             )
         )
 
+        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
+        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+
     def init_profier(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
@@ -1126,6 +1132,7 @@ class Scheduler(
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
             data_parallel_rank=recv_req.data_parallel_rank,
+            vocab_size=self.model_config.vocab_size,
         )
         req.tokenizer = self.tokenizer
 
@@ -1392,8 +1399,10 @@ class Scheduler(
         logger.info(f)
 
         if self.enable_metrics:
-
-
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
@@ -1706,13 +1715,13 @@ class Scheduler(
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
-        if self.
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])
 
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -2431,6 +2440,37 @@ class Scheduler(
                 req.grammar.cancel()
             req.set_finish_with_abort("Aborted by AbortReq.")
 
+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
+                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for i, req in enumerate(self.disagg_prefill_inflight_queue):
+                logger.debug(f"Abort inflight queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
+                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
+                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
         # Delete requests in the running batch
         if self.cur_batch is self.running_batch or self.cur_batch is None:
             reqs = self.running_batch.reqs
@@ -2466,12 +2506,6 @@ class Scheduler(
         """In-place loading a new lora adapter from disk or huggingface."""
 
         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result
 
     def unload_lora_adapter(
@@ -2480,14 +2514,6 @@ class Scheduler(
         """Unload the lora adapter."""
 
         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result
 
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -2909,9 +2935,9 @@ def run_scheduler_process(
         prefix += f" PP{pp_rank}"
 
     # Config the process
-    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
+    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()
 
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2926,10 +2952,6 @@ def run_scheduler_process(
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    embedding_cache_size = 100
-    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
-        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
-    init_embedding_cache(embedding_cache_size * 1024 * 1024)
    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2940,8 +2962,8 @@ def run_scheduler_process(
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
-        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
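
The scheduler diff above extends AbortReq handling to the prefill/decode disaggregation queues. The sketch below only illustrates the request-matching rule those new loops share (abort everything, or match requests whose rid starts with the given prefix); the `_Req` and `_AbortReq` classes are simplified stand-ins for illustration, not sglang's actual dataclasses.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class _Req:
    """Hypothetical stand-in for a queued request; only the field the abort path needs."""
    rid: str


@dataclass
class _AbortReq:
    """Hypothetical stand-in for AbortReq: abort everything, or match by rid prefix."""
    rid: str = ""
    abort_all: bool = False


def select_reqs_to_abort(queue: List[_Req], recv_req: _AbortReq) -> List[_Req]:
    # Same matching rule the new hunks apply to the bootstrap / inflight /
    # prealloc / transfer queues: abort_all wins, otherwise match by rid prefix.
    return [
        req for req in queue
        if recv_req.abort_all or req.rid.startswith(recv_req.rid)
    ]


if __name__ == "__main__":
    queue = [_Req("abc-1"), _Req("abc-2"), _Req("xyz-9")]
    print([r.rid for r in select_reqs_to_abort(queue, _AbortReq(rid="abc"))])       # ['abc-1', 'abc-2']
    print([r.rid for r in select_reqs_to_abort(queue, _AbortReq(abort_all=True))])  # all three
```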
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -111,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -165,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
@@ -215,12 +227,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )
 
         if server_args.skip_tokenizer_init:
@@ -242,11 +255,11 @@ class TokenizerManager:
                 revision=server_args.revision,
             )
 
-        # Initialize
-        #
-
-
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
 
         # Store states
         self.no_create_loop = False
@@ -269,6 +282,11 @@ class TokenizerManager:
             None
         )
 
+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -523,6 +541,11 @@ class TokenizerManager:
         else:
             mm_inputs = None
 
+        if self.server_args.enable_lora and obj.lora_path:
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +597,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)
 
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +710,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )
 
-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -747,6 +753,10 @@ class TokenizerManager:
             msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)
 
+        # Mark ongoing LoRA request as finished.
+        if self.server_args.enable_lora and obj.lora_path:
+            await self.lora_registry.release(obj.lora_path)
+
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
             finish_reason = out["meta_info"]["finish_reason"]
@@ -1053,9 +1063,21 @@ class TokenizerManager:
             obj.lora_path,
         )
 
-        async with self.
+        async with self.lora_update_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Trigger the actual loading operation at the backend processes.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-
+
+            # Register the LoRA adapter only after loading is successful.
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
         return result
 
     async def unload_lora_adapter(
@@ -1069,6 +1091,10 @@ class TokenizerManager:
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )
 
+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1079,9 +1105,17 @@ class TokenizerManager:
             obj.lora_name,
         )
 
-        async with self.
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-
+
         return result
 
     async def get_weights_by_name(
@@ -1309,7 +1343,7 @@ class TokenizerManager:
             filename = os.path.join(
                 self.crash_dump_folder,
                 os.getenv("HOSTNAME", None),
-                f
+                f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
             )
 
             os.makedirs(os.path.dirname(filename), exist_ok=True)
sglang/srt/managers/tp_worker.py
CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
@@ -278,6 +279,8 @@ class TpModelWorker:
         return success, message
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
@@ -293,11 +296,9 @@ class TpModelWorker:
         return parameter
 
     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result
 
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         self._kvcache = kvcache
 
         self.free_pages = None
+        self.release_pages = None
         self.is_not_in_free_group = True
         self.free_group = []
 
@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         return ""
 
     def available_size(self):
-        return len(self.free_pages) * self.page_size
+        return (len(self.free_pages) + len(self.release_pages)) * self.page_size
 
     def get_kvcache(self):
         return self._kvcache
 
-    def restore_state(self,
-        self.free_pages =
+    def restore_state(self, state):
+        self.free_pages, self.release_pages = state
 
     def backup_state(self):
-        return self.free_pages
+        return (self.free_pages, self.release_pages)
 
     def free_group_begin(self):
         self.is_not_in_free_group = False
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
+    def merge_and_sort_free(self):
+        if len(self.release_pages) > 0:
+            self.free_pages = torch.cat((self.free_pages, self.release_pages))
+            self.free_pages, _ = torch.sort(self.free_pages)
+            self.release_pages = torch.empty(
+                (0,), dtype=self.release_pages.dtype, device=self.device
+            )
+
     def get_cpu_copy(self, *args, **kwargs):
         # FIXME: reuse the get_cpu_copy after paged allocator is implemented
         raise NotImplementedError()
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
 
     def available_size(self):
         # To avoid minor "len(free_pages) * 1" overhead
-        return len(self.free_pages)
+        return len(self.free_pages) + len(self.release_pages)
 
     def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None
 
@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return
 
         if self.is_not_in_free_group:
-            self.
+            self.release_pages = torch.cat((self.release_pages, free_index))
         else:
             self.free_group.append(free_index)
 
@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"
 
         num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
 
@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.
+            self.release_pages = torch.cat((free_page_indices, self.release_pages))
         else:
             self.free_group.append(free_index)
 
@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
 
     def get_cpu_copy(self, indices):
         return self._kvcache.get_cpu_copy(indices)
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
 
@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def clear(self):
         super().clear()
         self.free_pages = self.free_pages.to(torch.int32)
+        self.release_pages = self.release_pages.to(torch.int32)
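
The allocator diff introduces a second free list: pages freed outside a free group are parked in `release_pages`, and `merge_and_sort_free` folds them back into the sorted `free_pages` pool only when an allocation would otherwise fail. Below is a standalone sketch of that deferred-merge pattern, assuming plain 1-D CPU index tensors for simplicity; it is not the packaged allocator.

```python
from typing import Optional

import torch


class DeferredFreeList:
    """Sketch of the free_pages / release_pages split used by the new allocator code."""

    def __init__(self, num_pages: int):
        self.free_pages = torch.arange(num_pages, dtype=torch.int64)
        self.release_pages = torch.empty((0,), dtype=torch.int64)

    def available_size(self) -> int:
        # Both pools count toward capacity, mirroring the updated available_size().
        return len(self.free_pages) + len(self.release_pages)

    def merge_and_sort_free(self) -> None:
        if len(self.release_pages) > 0:
            self.free_pages = torch.sort(
                torch.cat((self.free_pages, self.release_pages))
            ).values
            self.release_pages = torch.empty((0,), dtype=torch.int64)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        # Only pay the sort cost when the primary pool is about to run dry.
        if need_size > len(self.free_pages):
            self.merge_and_sort_free()
        if need_size > len(self.free_pages):
            return None
        out, self.free_pages = self.free_pages[:need_size], self.free_pages[need_size:]
        return out

    def free(self, indices: torch.Tensor) -> None:
        # Freed pages are parked in release_pages instead of being re-sorted immediately.
        self.release_pages = torch.cat((self.release_pages, indices))


if __name__ == "__main__":
    pool = DeferredFreeList(4)
    a = pool.alloc(3)
    pool.free(a[:2])
    print(pool.available_size())  # 3: one page still free, two parked for release
    print(pool.alloc(3))          # triggers merge_and_sort_free, then succeeds
```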
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -9,6 +9,12 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
 
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]
 
     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True
 
     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
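
The HiCacheFile change namespaces on-disk cache entries per tensor-parallel shard by appending a `_{tp_rank}_{tp_size}` suffix to every key (only when tp_size > 1), so shards sharing one directory no longer overwrite each other's KV tensors. A minimal sketch of the suffixing rule and the resulting file names; the directory and rank values below are made-up examples.

```python
import os


def suffixed_key(key: str, tp_rank: int, tp_size: int) -> str:
    # Same rule as HiCacheFile._get_suffixed_key: no suffix for single-shard runs.
    suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
    return key + suffix


def tensor_path(base_dir: str, key: str, tp_rank: int, tp_size: int) -> str:
    return os.path.join(base_dir, f"{suffixed_key(key, tp_rank, tp_size)}.bin")


if __name__ == "__main__":
    print(tensor_path("/tmp/hicache", "abc123", tp_rank=0, tp_size=1))  # /tmp/hicache/abc123.bin
    print(tensor_path("/tmp/hicache", "abc123", tp_rank=1, tp_size=4))  # /tmp/hicache/abc123_1_4.bin
```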