sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -247,7 +247,7 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.lora_paths = server_args.lora_paths
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -458,7 +458,10 @@ class Scheduler(
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
             self.grammar_backend = create_grammar_backend(
-                server_args, self.tokenizer, self.model_config.vocab_size
+                server_args,
+                self.tokenizer,
+                self.model_config.vocab_size,
+                self.model_config.hf_eos_token_id,
             )
         else:
             self.grammar_backend = None
@@ -653,6 +656,9 @@ class Scheduler(
             )
         )

+        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
+        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+
     def init_profier(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
@@ -1126,6 +1132,7 @@ class Scheduler(
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
             data_parallel_rank=recv_req.data_parallel_rank,
+            vocab_size=self.model_config.vocab_size,
         )
         req.tokenizer = self.tokenizer

@@ -1392,8 +1399,10 @@ class Scheduler(
         logger.info(f)

         if self.enable_metrics:
-            cache_hit_rate = adder.log_hit_tokens / (
-                adder.log_input_tokens + adder.log_hit_tokens
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
@@ -1706,13 +1715,13 @@ class Scheduler(
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

-        if self.lora_paths:
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.lora_paths
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -2431,6 +2440,37 @@ class Scheduler(
                 req.grammar.cancel()
             req.set_finish_with_abort("Aborted by AbortReq.")

+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
+                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for i, req in enumerate(self.disagg_prefill_inflight_queue):
+                logger.debug(f"Abort inflight queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
+                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
+                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
         # Delete requests in the running batch
         if self.cur_batch is self.running_batch or self.cur_batch is None:
             reqs = self.running_batch.reqs
@@ -2466,12 +2506,6 @@ class Scheduler(
         """In-place loading a new lora adapter from disk or huggingface."""

         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result

     def unload_lora_adapter(
@@ -2480,14 +2514,6 @@ class Scheduler(
         """Unload the lora adapter."""

         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -2909,9 +2935,9 @@ def run_scheduler_process(
         prefix += f" PP{pp_rank}"

     # Config the process
-    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
+    kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2926,10 +2952,6 @@
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    embedding_cache_size = 100
-    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
-        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
-    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2940,8 +2962,8 @@
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
             if server_args.pp_size > 1:
                 scheduler.event_loop_pp()
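
The multimodal embedding cache is now sized inside Scheduler.__init__ rather than once in run_scheduler_process (compare the hunks at lines 653 and 2926 above). A minimal sketch of the sizing rule, assuming only what the added lines show: SGLANG_VLM_CACHE_SIZE_MB is read as megabytes, defaults to 100, and is converted to bytes before being handed to init_embedding_cache.

# Sketch of the cache-sizing rule added to Scheduler.__init__ above; the real
# init_embedding_cache is part of sglang and is not reproduced here.
import os

def vlm_embedding_cache_bytes() -> int:
    # SGLANG_VLM_CACHE_SIZE_MB is interpreted as megabytes, default 100 MB.
    size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
    return size_mb * 1024 * 1024

print(vlm_embedding_cache_bytes())  # 104857600 when the variable is unset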
sglang/srt/managers/tokenizer_manager.py

@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -111,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -165,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""

@@ -215,12 +227,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )

         if server_args.skip_tokenizer_init:
@@ -242,11 +255,11 @@ class TokenizerManager:
                 revision=server_args.revision,
             )

-        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
-        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
-        self.loaded_lora_adapters: Dict[str, str] = dict(
-            self.server_args.lora_paths or {}
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

         # Store states
         self.no_create_loop = False
@@ -269,6 +282,11 @@ class TokenizerManager:
             None
         )

+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -523,6 +541,11 @@ class TokenizerManager:
         else:
             mm_inputs = None

+        if self.server_args.enable_lora and obj.lora_path:
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +597,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +710,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )

-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -747,6 +753,10 @@ class TokenizerManager:
             msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)

+        # Mark ongoing LoRA request as finished.
+        if self.server_args.enable_lora and obj.lora_path:
+            await self.lora_registry.release(obj.lora_path)
+
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
             finish_reason = out["meta_info"]["finish_reason"]
@@ -1053,9 +1063,21 @@ class TokenizerManager:
             obj.lora_path,
         )

-        async with self.model_update_lock.writer_lock:
+        async with self.lora_update_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Trigger the actual loading operation at the backend processes.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+
+            # Register the LoRA adapter only after loading is successful.
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
         return result

     async def unload_lora_adapter(
@@ -1069,6 +1091,10 @@ class TokenizerManager:
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )

+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1079,9 +1105,17 @@ class TokenizerManager:
             obj.lora_name,
         )

-        async with self.model_update_lock.writer_lock:
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+
         return result

     async def get_weights_by_name(
@@ -1309,7 +1343,7 @@ class TokenizerManager:
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
         )

         os.makedirs(os.path.dirname(filename), exist_ok=True)
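
Most of the tokenizer_manager.py changes lean on the new LoRARegistry added in sglang/srt/lora/lora_registry.py (file 64 in the list above, not shown in this diff). The sketch below is not that implementation; it only illustrates, under assumptions, the contract the calls above rely on: acquire maps a user-facing adapter name to its unique ID and counts the in-flight request, release decrements the count, unregister removes the name so new requests stop resolving, and wait_for_unload blocks until the count drains so the backend unload never races with running requests.

# Illustrative sketch only: the real LoRARef/LoRARegistry ship in the new
# sglang/srt/lora/lora_registry.py, which is not shown in this diff. The
# internals below are assumptions inferred from how TokenizerManager calls
# acquire/release/register/unregister/wait_for_unload in the hunks above.
import asyncio
import uuid
from dataclasses import dataclass, field


@dataclass
class LoRARef:
    lora_name: str
    lora_path: str
    lora_id: str = field(default_factory=lambda: uuid.uuid4().hex)


class LoRARegistry:
    def __init__(self, lora_paths=None):
        self._refs = {}      # lora_name -> LoRARef
        self._counters = {}  # lora_id -> number of in-flight requests
        self._drained = {}   # lora_id -> event set when the counter reaches zero
        for name, path in (lora_paths or {}).items():
            self._refs[name] = LoRARef(lora_name=name, lora_path=path)

    async def register(self, ref: LoRARef) -> None:
        self._refs[ref.lora_name] = ref

    async def unregister(self, lora_name: str) -> str:
        # New requests can no longer resolve this name; return the internal ID.
        ref = self._refs.pop(lora_name)
        return ref.lora_id

    async def acquire(self, lora_name: str) -> str:
        # Map the user-facing name to the internal ID and count the request.
        lora_id = self._refs[lora_name].lora_id
        self._counters[lora_id] = self._counters.get(lora_id, 0) + 1
        return lora_id

    async def release(self, lora_id: str) -> None:
        self._counters[lora_id] -= 1
        if self._counters[lora_id] == 0 and lora_id in self._drained:
            self._drained[lora_id].set()

    async def wait_for_unload(self, lora_id: str) -> None:
        # Block until every request that acquired this adapter has released it.
        if self._counters.get(lora_id, 0) > 0:
            event = self._drained.setdefault(lora_id, asyncio.Event())
            await event.wait()

The real registry also accepts lists of adapter names and handles error paths that this sketch leaves out.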
sglang/srt/managers/tp_worker.py

@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

@@ -278,6 +279,8 @@ class TpModelWorker:
         return success, message

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
@@ -293,11 +296,9 @@ class TpModelWorker:
         return parameter

     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result

     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
sglang/srt/mem_cache/allocator.py

@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         self._kvcache = kvcache

         self.free_pages = None
+        self.release_pages = None
         self.is_not_in_free_group = True
         self.free_group = []

@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         return ""

     def available_size(self):
-        return len(self.free_pages) * self.page_size
+        return (len(self.free_pages) + len(self.release_pages)) * self.page_size

     def get_kvcache(self):
         return self._kvcache

-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
+    def restore_state(self, state):
+        self.free_pages, self.release_pages = state

     def backup_state(self):
-        return self.free_pages
+        return (self.free_pages, self.release_pages)

     def free_group_begin(self):
         self.is_not_in_free_group = False
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def merge_and_sort_free(self):
+        if len(self.release_pages) > 0:
+            self.free_pages = torch.cat((self.free_pages, self.release_pages))
+            self.free_pages, _ = torch.sort(self.free_pages)
+            self.release_pages = torch.empty(
+                (0,), dtype=self.release_pages.dtype, device=self.device
+            )
+
     def get_cpu_copy(self, *args, **kwargs):
         # FIXME: reuse the get_cpu_copy after paged allocator is implemented
         raise NotImplementedError()
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

     def available_size(self):
         # To avoid minor "len(free_pages) * 1" overhead
-        return len(self.free_pages)
+        return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None

@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-            self.free_pages = torch.cat((self.free_pages, free_index))
+            self.release_pages = torch.cat((self.release_pages, free_index))
         else:
             self.free_group.append(free_index)

@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None

@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.free_pages = torch.cat((free_page_indices, self.free_pages))
+            self.release_pages = torch.cat((free_page_indices, self.release_pages))
         else:
             self.free_group.append(free_index)

@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

     def get_cpu_copy(self, indices):
         return self._kvcache.get_cpu_copy(indices)
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def clear(self):
         super().clear()
         self.free_pages = self.free_pages.to(torch.int32)
+        self.release_pages = self.release_pages.to(torch.int32)
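
All of the allocator.py hunks implement one idea: pages freed outside a free group are parked in release_pages instead of being concatenated straight into the sorted free_pages list, and the cat-plus-sort cost is paid lazily in merge_and_sort_free only when free_pages alone cannot satisfy an allocation (or when the estimated_num_new_pages check predicts it will not). A standalone sketch of that deferred merge, with made-up page numbers, just to show the mechanics:

# Deferred-merge sketch of the free_pages / release_pages split; an
# illustration of the pattern, not the sglang allocator itself.
import torch

free_pages = torch.tensor([2, 5, 9], dtype=torch.int64)   # kept sorted
release_pages = torch.empty((0,), dtype=torch.int64)      # unsorted parking lot

def free(page_indices):
    # Freeing is now just a cheap concatenation into release_pages.
    global release_pages
    release_pages = torch.cat((release_pages, page_indices))

def merge_and_sort_free():
    # Pay the sort cost only when an allocation actually needs the parked pages.
    global free_pages, release_pages
    if len(release_pages) > 0:
        free_pages, _ = torch.sort(torch.cat((free_pages, release_pages)))
        release_pages = torch.empty((0,), dtype=torch.int64)

def alloc(need_size):
    global free_pages
    if need_size > len(free_pages):
        merge_and_sort_free()
    if need_size > len(free_pages):
        return None
    out, free_pages = free_pages[:need_size], free_pages[need_size:]
    return out

free(torch.tensor([7, 1], dtype=torch.int64))
print(alloc(4))  # forces the merge and returns tensor([1, 2, 5, 7])

backup_state, restore_state, and available_size change accordingly, since the usable pool is now the union of both tensors.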
sglang/srt/mem_cache/hicache_storage.py

@@ -9,6 +9,12 @@ import torch
 logger = logging.getLogger(__name__)


+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()

@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):

     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]

     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True

     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)

     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
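
With the HiCacheFile changes above, every cache key gets a per-rank suffix so tensor-parallel ranks sharing one directory store their KV shards under distinct files, and only rank 0 creates the directory. A tiny illustration of the key scheme, mirroring _get_suffixed_key with made-up rank and world-size values:

# Mirrors the _get_suffixed_key logic from the hunk above; the rank and
# world-size values here are hypothetical.
def suffixed_key(key: str, tp_rank: int, tp_size: int) -> str:
    tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
    return key + tp_suffix

print(suffixed_key("3f2a9c", tp_rank=0, tp_size=1))  # 3f2a9c      (single rank, key unchanged)
print(suffixed_key("3f2a9c", tp_rank=1, tp_size=4))  # 3f2a9c_1_4  (rank 1 of 4)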