sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/_custom_ops.py +29 -1
  2. sglang/srt/configs/model_config.py +1 -1
  3. sglang/srt/conversation.py +1 -1
  4. sglang/srt/disaggregation/common/conn.py +34 -6
  5. sglang/srt/disaggregation/mini_lb.py +3 -2
  6. sglang/srt/disaggregation/mooncake/conn.py +49 -20
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  8. sglang/srt/disaggregation/nixl/conn.py +17 -13
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  10. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  11. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  12. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  13. sglang/srt/distributed/parallel_state.py +70 -15
  14. sglang/srt/entrypoints/engine.py +2 -8
  15. sglang/srt/entrypoints/http_server.py +20 -32
  16. sglang/srt/entrypoints/openai/protocol.py +3 -3
  17. sglang/srt/entrypoints/openai/serving_chat.py +27 -4
  18. sglang/srt/function_call/base_format_detector.py +74 -12
  19. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  20. sglang/srt/function_call/ebnf_composer.py +95 -63
  21. sglang/srt/function_call/function_call_parser.py +4 -4
  22. sglang/srt/function_call/kimik2_detector.py +41 -16
  23. sglang/srt/function_call/llama32_detector.py +6 -3
  24. sglang/srt/function_call/mistral_detector.py +11 -3
  25. sglang/srt/function_call/pythonic_detector.py +16 -14
  26. sglang/srt/function_call/qwen25_detector.py +12 -3
  27. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
  28. sglang/srt/layers/activation.py +11 -3
  29. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  30. sglang/srt/layers/communicator.py +12 -12
  31. sglang/srt/layers/dp_attention.py +72 -24
  32. sglang/srt/layers/logits_processor.py +34 -24
  33. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  35. sglang/srt/layers/moe/topk.py +5 -13
  36. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  37. sglang/srt/layers/quantization/modelopt_quant.py +8 -4
  38. sglang/srt/layers/quantization/utils.py +0 -9
  39. sglang/srt/layers/radix_attention.py +5 -3
  40. sglang/srt/lora/lora_manager.py +133 -169
  41. sglang/srt/lora/lora_registry.py +124 -0
  42. sglang/srt/lora/mem_pool.py +2 -2
  43. sglang/srt/managers/cache_controller.py +53 -6
  44. sglang/srt/managers/io_struct.py +19 -1
  45. sglang/srt/managers/schedule_batch.py +13 -3
  46. sglang/srt/managers/scheduler.py +13 -25
  47. sglang/srt/managers/tokenizer_manager.py +28 -25
  48. sglang/srt/managers/tp_worker.py +2 -4
  49. sglang/srt/mem_cache/allocator.py +67 -7
  50. sglang/srt/mem_cache/hicache_storage.py +17 -1
  51. sglang/srt/mem_cache/hiradix_cache.py +30 -16
  52. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  53. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  54. sglang/srt/model_executor/forward_batch_info.py +201 -29
  55. sglang/srt/model_executor/model_runner.py +41 -23
  56. sglang/srt/models/deepseek_v2.py +1 -2
  57. sglang/srt/models/mllama4.py +10 -3
  58. sglang/srt/models/qwen2_moe.py +0 -4
  59. sglang/srt/models/qwen3_moe.py +1 -6
  60. sglang/srt/reasoning_parser.py +46 -4
  61. sglang/srt/sampling/sampling_batch_info.py +6 -5
  62. sglang/srt/server_args.py +76 -55
  63. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  65. sglang/srt/speculative/eagle_utils.py +51 -23
  66. sglang/srt/speculative/eagle_worker.py +59 -44
  67. sglang/srt/two_batch_overlap.py +9 -5
  68. sglang/srt/utils.py +17 -68
  69. sglang/test/test_activation.py +50 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
  72. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
  73. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -219,6 +219,7 @@ class HiCacheController:
219
219
  token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
220
220
  mem_pool_host: HostKVCache,
221
221
  page_size: int,
222
+ tp_group: torch.distributed.ProcessGroup,
222
223
  load_cache_event: threading.Event = None,
223
224
  write_policy: str = "write_through_selective",
224
225
  io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
244
245
  self.enable_storage = False
245
246
  # todo: move backend initialization to storage backend module
246
247
  if storage_backend is not None:
248
+ # create a new communication group for synchronizing storage operations across TP workers
249
+ self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
250
+ if self.tp_world_size > 1:
251
+ group_ranks = torch.distributed.get_process_group_ranks(tp_group)
252
+ self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
253
+
247
254
  if storage_backend == "file":
248
255
  self.storage_backend = HiCacheFile()
249
256
  self.enable_storage = True
250
257
  # todo: threshold policy for prefetching
251
- self.prefetch_threshold = prefetch_threshold
258
+ self.prefetch_threshold = max(prefetch_threshold, self.page_size)
252
259
  else:
253
260
  raise NotImplementedError(
254
261
  f"Unsupported storage backend: {storage_backend}"
@@ -358,6 +365,7 @@ class HiCacheController:
358
365
  if host_indices is None:
359
366
  return None
360
367
  self.mem_pool_host.protect_write(host_indices)
368
+ torch.cuda.current_stream().synchronize()
361
369
  self.write_queue.put(
362
370
  CacheOperation(host_indices, device_indices, node_id, priority)
363
371
  )
@@ -567,13 +575,32 @@ class HiCacheController:
567
575
  else:
568
576
  break
569
577
 
578
+ if self.tp_world_size > 1:
579
+ storage_hit_count_tensor = torch.tensor(
580
+ storage_hit_count, dtype=torch.int
581
+ )
582
+ torch.distributed.all_reduce(
583
+ storage_hit_count_tensor,
584
+ op=torch.distributed.ReduceOp.MIN,
585
+ group=self.tp_group,
586
+ )
587
+ storage_hit_count = storage_hit_count_tensor.item()
588
+
570
589
  if storage_hit_count < self.prefetch_threshold:
571
590
  # not to prefetch if not enough benefits
572
591
  self.prefetch_revoke_queue.put(operation.request_id)
592
+ logger.debug(
593
+ f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
594
+ )
573
595
  else:
574
- operation.hash_value = hash_value
596
+ operation.hash_value = hash_value[
597
+ : (storage_hit_count // self.page_size)
598
+ ]
599
+ # free the pre-allocated memory for pages that are not hit
600
+ self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
601
+ operation.host_indices = operation.host_indices[:storage_hit_count]
575
602
  logger.debug(
576
- f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
603
+ f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
577
604
  )
578
605
  self.prefetch_buffer.put(operation)
579
606
 
@@ -610,17 +637,37 @@ class HiCacheController:
610
637
  last_hash = get_hash_str(
611
638
  tokens_to_backup[i : i + self.page_size], last_hash
612
639
  )
613
- # todo, handle failures in storage backend
614
- self.storage_backend.set(
640
+ success = self.storage_backend.set(
615
641
  last_hash,
616
642
  self.mem_pool_host.get_flat_data_page(
617
643
  operation.host_indices[i]
618
644
  ),
619
645
  )
646
+ if not success:
647
+ logger.warning(f"Failed to write page {last_hash} to storage.")
648
+ break
620
649
  operation.completed_tokens += self.page_size
621
650
  operation.hash_value.append(last_hash)
622
651
 
623
- self.ack_backup_queue.put((operation.id, operation.hash_value))
652
+ min_completed_tokens = operation.completed_tokens
653
+ if self.tp_world_size > 1:
654
+ completed_tokens_tensor = torch.tensor(
655
+ min_completed_tokens, dtype=torch.int
656
+ )
657
+ torch.distributed.all_reduce(
658
+ completed_tokens_tensor,
659
+ op=torch.distributed.ReduceOp.MIN,
660
+ group=self.tp_group,
661
+ )
662
+ min_completed_tokens = completed_tokens_tensor.item()
663
+
664
+ self.ack_backup_queue.put(
665
+ (
666
+ operation.id,
667
+ operation.hash_value[: min_completed_tokens // self.page_size],
668
+ min_completed_tokens,
669
+ )
670
+ )
624
671
 
625
672
  except Empty:
626
673
  continue
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
22
22
  from enum import Enum
23
23
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
24
24
 
25
+ from sglang.srt.lora.lora_registry import LoRARef
25
26
  from sglang.srt.managers.schedule_batch import BaseFinishReason
26
27
  from sglang.srt.multimodal.mm_utils import has_valid_data
27
28
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
1067
1068
  lora_name: str
1068
1069
  # The path of loading.
1069
1070
  lora_path: str
1071
+ # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
1072
+ lora_id: Optional[str] = None
1073
+
1074
+ def to_ref(self) -> LoRARef:
1075
+ return LoRARef(
1076
+ lora_id=self.lora_id,
1077
+ lora_name=self.lora_name,
1078
+ lora_path=self.lora_path,
1079
+ )
1070
1080
 
1071
1081
 
1072
1082
  @dataclass
1073
1083
  class UnloadLoRAAdapterReqInput:
1074
1084
  # The name of lora module to unload.
1075
1085
  lora_name: str
1086
+ # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
1087
+ lora_id: Optional[str] = None
1088
+
1089
+ def to_ref(self) -> LoRARef:
1090
+ return LoRARef(
1091
+ lora_id=self.lora_id,
1092
+ lora_name=self.lora_name,
1093
+ )
1076
1094
 
1077
1095
 
1078
1096
  @dataclass
1079
1097
  class LoRAUpdateResult:
1080
1098
  success: bool
1081
1099
  error_message: Optional[str] = None
1082
- loaded_adapters: Dict[str, str] = field(default_factory=dict)
1100
+ loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
1083
1101
 
1084
1102
 
1085
1103
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@@ -45,7 +45,6 @@ import triton
45
45
  import triton.language as tl
46
46
 
47
47
  from sglang.global_config import global_config
48
- from sglang.srt.configs.model_config import ModelConfig
49
48
  from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
50
49
  from sglang.srt.disaggregation.base import BaseKVSender
51
50
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
68
67
  from sglang.srt.utils import flatten_nested_list, support_triton
69
68
 
70
69
  if TYPE_CHECKING:
70
+ from sglang.srt.configs.model_config import ModelConfig
71
71
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
72
72
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
73
73
 
@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
106
106
  "num_reserved_decode_tokens",
107
107
  "weight_loader_disable_mmap",
108
108
  "enable_triton_kernel_moe",
109
+ "enable_multimodal",
109
110
  ]
110
111
 
111
112
  # Put some global args for easy access
@@ -430,6 +431,7 @@ class Req:
430
431
  bootstrap_port: Optional[int] = None,
431
432
  bootstrap_room: Optional[int] = None,
432
433
  data_parallel_rank: Optional[int] = None,
434
+ vocab_size: Optional[int] = None,
433
435
  ):
434
436
  # Input and output info
435
437
  self.rid = rid
@@ -479,6 +481,7 @@ class Req:
479
481
  self.to_abort_message: str = None
480
482
  self.stream = stream
481
483
  self.eos_token_ids = eos_token_ids
484
+ self.vocab_size = vocab_size
482
485
 
483
486
  # For incremental decoding
484
487
  # ----- | --------- read_ids -------|
@@ -712,6 +715,14 @@ class Req:
712
715
  self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
713
716
  return
714
717
 
718
+ if last_token_id > self.vocab_size or last_token_id < 0:
719
+ if self.sampling_params.stop_token_ids:
720
+ self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
721
+ if self.eos_token_ids:
722
+ self.output_ids[-1] = next(iter(self.eos_token_ids))
723
+ self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
724
+ return
725
+
715
726
  # Check stop strings
716
727
  if len(self.sampling_params.stop_strs) > 0:
717
728
  tail_str = self.tokenizer.decode(
@@ -1879,7 +1890,7 @@ class ModelWorkerBatch:
1879
1890
  sampling_info: SamplingBatchInfo
1880
1891
 
1881
1892
  # The input Embeds
1882
- input_embeds: Optional[torch.tensor] = None
1893
+ input_embeds: Optional[torch.Tensor] = None
1883
1894
 
1884
1895
  # For cross-encoder model
1885
1896
  token_type_ids: Optional[torch.Tensor] = None
@@ -1889,7 +1900,6 @@ class ModelWorkerBatch:
1889
1900
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
1890
1901
  # If set, the output of the batch contains the hidden states of the run.
1891
1902
  capture_hidden_mode: CaptureHiddenMode = None
1892
- spec_num_draft_tokens: Optional[int] = None
1893
1903
  hicache_consumer_index: int = 0
1894
1904
 
1895
1905
  # Overlap event
@@ -247,7 +247,7 @@ class Scheduler(
247
247
  self.pp_size = server_args.pp_size
248
248
  self.dp_size = server_args.dp_size
249
249
  self.schedule_policy = server_args.schedule_policy
250
- self.lora_paths = server_args.lora_paths
250
+ self.enable_lora = server_args.enable_lora
251
251
  self.max_loras_per_batch = server_args.max_loras_per_batch
252
252
  self.enable_overlap = not server_args.disable_overlap_schedule
253
253
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -653,6 +653,9 @@ class Scheduler(
653
653
  )
654
654
  )
655
655
 
656
+ embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
657
+ init_embedding_cache(embedding_cache_size * 1024 * 1024)
658
+
656
659
  def init_profier(self):
657
660
  self.torch_profiler = None
658
661
  self.torch_profiler_output_dir: Optional[str] = None
@@ -1126,6 +1129,7 @@ class Scheduler(
1126
1129
  bootstrap_port=recv_req.bootstrap_port,
1127
1130
  bootstrap_room=recv_req.bootstrap_room,
1128
1131
  data_parallel_rank=recv_req.data_parallel_rank,
1132
+ vocab_size=self.model_config.vocab_size,
1129
1133
  )
1130
1134
  req.tokenizer = self.tokenizer
1131
1135
 
@@ -1392,8 +1396,10 @@ class Scheduler(
1392
1396
  logger.info(f)
1393
1397
 
1394
1398
  if self.enable_metrics:
1395
- cache_hit_rate = adder.log_hit_tokens / (
1396
- adder.log_input_tokens + adder.log_hit_tokens
1399
+ total_tokens = adder.log_input_tokens + adder.log_hit_tokens
1400
+
1401
+ cache_hit_rate = (
1402
+ adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
1397
1403
  )
1398
1404
  self.stats.num_running_reqs = running_bs
1399
1405
  self.stats.num_used_tokens = num_used
@@ -1706,13 +1712,13 @@ class Scheduler(
1706
1712
  self.chunked_req.init_next_round_input()
1707
1713
  self.chunked_req = adder.add_chunked_req(self.chunked_req)
1708
1714
 
1709
- if self.lora_paths:
1715
+ if self.enable_lora:
1710
1716
  lora_set = set([req.lora_path for req in self.running_batch.reqs])
1711
1717
 
1712
1718
  # Get requests from the waiting queue to a new prefill batch
1713
1719
  for req in self.waiting_queue:
1714
1720
  if (
1715
- self.lora_paths
1721
+ self.enable_lora
1716
1722
  and len(
1717
1723
  lora_set
1718
1724
  | set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2472,6 @@ class Scheduler(
2466
2472
  """In-place loading a new lora adapter from disk or huggingface."""
2467
2473
 
2468
2474
  result = self.tp_worker.load_lora_adapter(recv_req)
2469
-
2470
- if result.success:
2471
- flush_cache_success = self.flush_cache()
2472
- assert flush_cache_success, "Cache flush failed after loading lora adapter."
2473
- else:
2474
- logger.error(result.error_message)
2475
2475
  return result
2476
2476
 
2477
2477
  def unload_lora_adapter(
@@ -2480,14 +2480,6 @@ class Scheduler(
2480
2480
  """Unload the lora adapter."""
2481
2481
 
2482
2482
  result = self.tp_worker.unload_lora_adapter(recv_req)
2483
-
2484
- if result.success:
2485
- flush_cache_success = self.flush_cache()
2486
- assert (
2487
- flush_cache_success
2488
- ), "Cache flush failed after unloading LoRA weights"
2489
- else:
2490
- logger.error(result.error_message)
2491
2483
  return result
2492
2484
 
2493
2485
  def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -2909,9 +2901,9 @@ def run_scheduler_process(
2909
2901
  prefix += f" PP{pp_rank}"
2910
2902
 
2911
2903
  # Config the process
2912
- kill_itself_when_parent_died()
2913
2904
  setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
2914
2905
  faulthandler.enable()
2906
+ kill_itself_when_parent_died()
2915
2907
  parent_process = psutil.Process().parent()
2916
2908
 
2917
2909
  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2926,10 +2918,6 @@ def run_scheduler_process(
2926
2918
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
2927
2919
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
2928
2920
 
2929
- embedding_cache_size = 100
2930
- if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
2931
- embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
2932
- init_embedding_cache(embedding_cache_size * 1024 * 1024)
2933
2921
  # Create a scheduler and run the event loop
2934
2922
  try:
2935
2923
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2940,8 +2928,8 @@ def run_scheduler_process(
2940
2928
  "max_req_input_len": scheduler.max_req_input_len,
2941
2929
  }
2942
2930
  )
2943
- disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
2944
2931
 
2932
+ disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
2945
2933
  if disaggregation_mode == DisaggregationMode.NULL:
2946
2934
  if server_args.pp_size > 1:
2947
2935
  scheduler.event_loop_pp()
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
62
62
  get_tokenizer,
63
63
  get_tokenizer_from_processor,
64
64
  )
65
+ from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
65
66
  from sglang.srt.managers.io_struct import (
66
67
  AbortReq,
67
68
  BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
242
243
  revision=server_args.revision,
243
244
  )
244
245
 
245
- # Initialize loaded loRA adapters with the initial lora paths in the server_args.
246
- # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
247
- self.loaded_lora_adapters: Dict[str, str] = dict(
248
- self.server_args.lora_paths or {}
249
- )
246
+ # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
247
+ # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
248
+ # serves as the source of truth for available adapters and maps user-friendly LoRA names
249
+ # to internally used unique LoRA IDs.
250
+ self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
250
251
 
251
252
  # Store states
252
253
  self.no_create_loop = False
@@ -523,6 +524,10 @@ class TokenizerManager:
523
524
  else:
524
525
  mm_inputs = None
525
526
 
527
+ if self.server_args.enable_lora and obj.lora_path:
528
+ # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
529
+ obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
530
+
526
531
  self._validate_one_request(obj, input_ids)
527
532
  return self._create_tokenized_object(
528
533
  obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
574
579
  "The server is not configured to enable custom logit processor. "
575
580
  "Please set `--enable-custom-logits-processor` to enable this feature."
576
581
  )
577
- if self.server_args.enable_lora and obj.lora_path:
578
- self._validate_lora_adapters(obj)
579
582
 
580
583
  def _validate_input_ids_in_vocab(
581
584
  self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
689
692
  "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
690
693
  )
691
694
 
692
- def _validate_lora_adapters(self, obj: GenerateReqInput):
693
- """Validate that the requested LoRA adapters are loaded."""
694
- requested_adapters = (
695
- set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
696
- )
697
- loaded_adapters = (
698
- self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
699
- )
700
- unloaded_adapters = requested_adapters - loaded_adapters
701
- if unloaded_adapters:
702
- raise ValueError(
703
- f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
704
- f"Loaded adapters: {loaded_adapters}."
705
- )
706
-
707
695
  def _send_one_request(
708
696
  self,
709
697
  obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
1054
1042
  )
1055
1043
 
1056
1044
  async with self.model_update_lock.writer_lock:
1045
+ # Generate new uniquely identifiable LoRARef object.
1046
+ new_adapter = LoRARef(
1047
+ lora_name=obj.lora_name,
1048
+ lora_path=obj.lora_path,
1049
+ )
1050
+
1051
+ # Register the new adapter in the registry.
1052
+ obj.lora_id = new_adapter.lora_id
1057
1053
  result = (await self.update_lora_adapter_communicator(obj))[0]
1058
- self.loaded_lora_adapters = result.loaded_adapters
1054
+ if result.success:
1055
+ await self.lora_registry.register(new_adapter)
1056
+
1059
1057
  return result
1060
1058
 
1061
1059
  async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@ class TokenizerManager:
1069
1067
  "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
1070
1068
  )
1071
1069
 
1070
+ assert (
1071
+ obj.lora_name is not None
1072
+ ), "lora_name must be provided to unload LoRA adapter"
1073
+
1072
1074
  # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
1073
1075
  # with dp_size > 1.
1074
1076
  assert (
@@ -1080,8 +1082,9 @@ class TokenizerManager:
1080
1082
  )
1081
1083
 
1082
1084
  async with self.model_update_lock.writer_lock:
1085
+ obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
1083
1086
  result = (await self.update_lora_adapter_communicator(obj))[0]
1084
- self.loaded_lora_adapters = result.loaded_adapters
1087
+
1085
1088
  return result
1086
1089
 
1087
1090
  async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ class TokenizerManager:
1309
1312
  filename = os.path.join(
1310
1313
  self.crash_dump_folder,
1311
1314
  os.getenv("HOSTNAME", None),
1312
- f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
1315
+ f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
1313
1316
  )
1314
1317
 
1315
1318
  os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -293,11 +293,9 @@ class TpModelWorker:
293
293
  return parameter
294
294
 
295
295
  def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
296
- result = self.model_runner.load_lora_adapter(
297
- recv_req.lora_name, recv_req.lora_path
298
- )
296
+ result = self.model_runner.load_lora_adapter(recv_req.to_ref())
299
297
  return result
300
298
 
301
299
  def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
302
- result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
300
+ result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
303
301
  return result
@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
51
51
  self._kvcache = kvcache
52
52
 
53
53
  self.free_pages = None
54
+ self.release_pages = None
54
55
  self.is_not_in_free_group = True
55
56
  self.free_group = []
56
57
 
@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
58
59
  return ""
59
60
 
60
61
  def available_size(self):
61
- return len(self.free_pages) * self.page_size
62
+ return (len(self.free_pages) + len(self.release_pages)) * self.page_size
62
63
 
63
64
  def get_kvcache(self):
64
65
  return self._kvcache
65
66
 
66
- def restore_state(self, free_pages):
67
- self.free_pages = free_pages
67
+ def restore_state(self, state):
68
+ self.free_pages, self.release_pages = state
68
69
 
69
70
  def backup_state(self):
70
- return self.free_pages
71
+ return (self.free_pages, self.release_pages)
71
72
 
72
73
  def free_group_begin(self):
73
74
  self.is_not_in_free_group = False
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
78
79
  if self.free_group:
79
80
  self.free(torch.cat(self.free_group))
80
81
 
82
+ def merge_and_sort_free(self):
83
+ if len(self.release_pages) > 0:
84
+ self.free_pages = torch.cat((self.free_pages, self.release_pages))
85
+ self.free_pages, _ = torch.sort(self.free_pages)
86
+ self.release_pages = torch.empty(
87
+ (0,), dtype=self.release_pages.dtype, device=self.device
88
+ )
89
+
81
90
  def get_cpu_copy(self, *args, **kwargs):
82
91
  # FIXME: reuse the get_cpu_copy after paged allocator is implemented
83
92
  raise NotImplementedError()
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
119
128
  )
120
129
  self.is_not_in_free_group = True
121
130
  self.free_group = []
131
+ self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
122
132
 
123
133
  def available_size(self):
124
134
  # To avoid minor "len(free_pages) * 1" overhead
125
- return len(self.free_pages)
135
+ return len(self.free_pages) + len(self.release_pages)
126
136
 
127
137
  def alloc(self, need_size: int):
138
+ if need_size > len(self.free_pages):
139
+ self.merge_and_sort_free()
128
140
  if need_size > len(self.free_pages):
129
141
  return None
130
142
 
@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
137
149
  return
138
150
 
139
151
  if self.is_not_in_free_group:
140
- self.free_pages = torch.cat((self.free_pages, free_index))
152
+ self.release_pages = torch.cat((self.release_pages, free_index))
141
153
  else:
142
154
  self.free_group.append(free_index)
143
155
 
@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
421
433
  ), "The allocation size should be page-aligned"
422
434
 
423
435
  num_pages = need_size // self.page_size
436
+ if num_pages > len(self.free_pages):
437
+ self.merge_and_sort_free()
424
438
  if num_pages > len(self.free_pages):
425
439
  return None
426
440
 
@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
446
460
  (last_loc + 1) % self.page_size == prefix_lens % self.page_size
447
461
  )
448
462
 
463
+ estimated_num_new_pages = (
464
+ (
465
+ (seq_lens + self.page_size - 1) // self.page_size
466
+ - (prefix_lens + self.page_size - 1) // self.page_size
467
+ )
468
+ .sum()
469
+ .item()
470
+ )
471
+ if estimated_num_new_pages > len(self.free_pages):
472
+ self.merge_and_sort_free()
473
+
449
474
  bs = len(prefix_lens)
450
475
  out_indices = torch.empty(
451
476
  (extend_num_tokens,), dtype=torch.int64, device=self.device
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
483
508
  (last_loc + 2) % self.page_size == seq_lens % self.page_size
484
509
  )
485
510
 
511
+ estimated_num_new_pages = (
512
+ (
513
+ (seq_lens + self.page_size - 1) // self.page_size
514
+ - (seq_lens - 1 + self.page_size - 1) // self.page_size
515
+ )
516
+ .sum()
517
+ .item()
518
+ )
519
+ if estimated_num_new_pages > len(self.free_pages):
520
+ self.merge_and_sort_free()
521
+
486
522
  bs = len(seq_lens)
487
523
  out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
488
524
  alloc_decode_kernel[(bs,)](
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
511
547
 
512
548
  if self.is_not_in_free_group:
513
549
  free_page_indices = torch.unique(free_index // self.page_size)
514
- self.free_pages = torch.cat((free_page_indices, self.free_pages))
550
+ self.release_pages = torch.cat((free_page_indices, self.release_pages))
515
551
  else:
516
552
  self.free_group.append(free_index)
517
553
 
@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
525
561
  )
526
562
  self.is_not_in_free_group = True
527
563
  self.free_group = []
564
+ self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
528
565
 
529
566
  def get_cpu_copy(self, indices):
530
567
  return self._kvcache.get_cpu_copy(indices)
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
633
670
  (last_loc + 1) % self.page_size == prefix_lens % self.page_size
634
671
  )
635
672
 
673
+ estimated_num_new_pages = (
674
+ (
675
+ (seq_lens + self.page_size - 1) // self.page_size
676
+ - (prefix_lens + self.page_size - 1) // self.page_size
677
+ )
678
+ .sum()
679
+ .item()
680
+ )
681
+ if estimated_num_new_pages > len(self.free_pages):
682
+ self.merge_and_sort_free()
683
+
636
684
  bs = len(prefix_lens)
637
685
  out_indices = torch.empty(
638
686
  (extend_num_tokens,), dtype=torch.int32, device=self.device
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
668
716
  (last_loc + 2) % self.page_size == seq_lens % self.page_size
669
717
  )
670
718
 
719
+ estimated_num_new_pages = (
720
+ (
721
+ (seq_lens + self.page_size - 1) // self.page_size
722
+ - (seq_lens - 1 + self.page_size - 1) // self.page_size
723
+ )
724
+ .sum()
725
+ .item()
726
+ )
727
+ if estimated_num_new_pages > len(self.free_pages):
728
+ self.merge_and_sort_free()
729
+
671
730
  bs = len(seq_lens)
672
731
  out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
673
732
 
@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
692
751
  def clear(self):
693
752
  super().clear()
694
753
  self.free_pages = self.free_pages.to(torch.int32)
754
+ self.release_pages = self.release_pages.to(torch.int32)