sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0

sglang/srt/managers/scheduler.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import datetime
 import faulthandler
 import logging
 import os
@@ -58,6 +59,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -65,9 +67,6 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -82,6 +81,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +100,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -126,7 +129,8 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -146,6 +150,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
+    is_cpu,
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
@@ -164,6 +169,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
+_is_cpu = is_cpu()
+
 
 @dataclass
 class GenerationBatchResult:
@@ -415,14 +422,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
+        self.num_retracted_reqs: int = 0
+        self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None
 
-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -470,26 +479,12 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
+
+        # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
 
@@ -518,9 +513,12 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )
 
+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -570,7 +568,11 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            self.tree_cache = ChunkCache(
+            if self.model_config.is_hybrid:
+                ChunkCacheClass = SWAChunkCache
+            else:
+                ChunkCacheClass = ChunkCache
+            self.tree_cache = ChunkCacheClass(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
@@ -589,6 +591,12 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
+                hicache_io_backend=(
+                    "direct"
+                    if server_args.attention_backend
+                    == "fa3"  # hot fix for incompatibility
+                    else server_args.hicache_io_backend
+                ),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -615,6 +623,21 @@ class Scheduler(
             )
         )
 
+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -694,9 +717,6 @@ class Scheduler(
                 transfer_backend=self.transfer_backend,
             )
 
-            # Metric for pre-allocation
-            self.num_tokens_pre_allocated = 0
-
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -915,7 +935,7 @@ class Scheduler(
                 point_to_point_pyobj(
                     recv_reqs,
                     self.pp_rank * self.tp_size + dp_offset,
-                    self.world_group.cpu_group,
+                    self.world_group.device_group,
                     self.pp_rank * self.tp_size + dp_offset,
                     (self.pp_rank + 1) * self.tp_size + dp_offset,
                 )
@@ -962,7 +982,7 @@ class Scheduler(
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
-                    self.world_group.cpu_group,
+                    self.world_group.device_group,
                     (self.pp_rank - 1) * self.tp_size + dp_offset,
                     self.pp_rank * self.tp_size + dp_offset,
                 )
@@ -1087,7 +1107,7 @@ class Scheduler(
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
         ):
-            req.finished_reason = FINISH_ABORT(
+            req.set_finish_with_abort(
                 f"Invalid request: session id {recv_req.session_params.id} does not exist"
             )
             self._add_request_to_queue(req)
@@ -1283,9 +1303,8 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         num_new_seq = len(can_run_list)
@@ -1294,17 +1313,19 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"{usage_msg}"
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
             f += f"#queue-req: {len(self.waiting_queue)}, "
             f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
         else:
             f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}"
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+
+        f += f"timestamp: {datetime.datetime.now().isoformat()}"
 
         logger.info(f)
 
@@ -1337,9 +1358,8 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         if RECORD_STEP_TIME:
@@ -1347,12 +1367,7 @@ class Scheduler(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = (
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-        )
+        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1366,13 +1381,14 @@ class Scheduler(
             msg += f"accept len: {spec_accept_length:.2f}, "
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
+            f"#queue-req: {len(self.waiting_queue)}, "
+            f"timestamp: {datetime.datetime.now().isoformat()}"
         )
 
         logger.info(msg)
@@ -1390,10 +1406,11 @@ class Scheduler(
             self._publish_kv_events()
 
     def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-        )
+        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
+        else:
+            available_token_size = self.token_to_kv_pool_allocator.available_size()
+        available_size = available_token_size + self.tree_cache.evictable_size()
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
             self.max_total_num_tokens
@@ -1404,7 +1421,7 @@ class Scheduler(
             msg = (
                 "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{available_token_size=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
             raise ValueError(msg)
@@ -1483,7 +1500,7 @@ class Scheduler(
         if need_dp_attn_preparation and not self.spec_algorithm.is_none():
             # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
             # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            new_batch = self.prepare_mlp_sync_batch(new_batch)
             need_dp_attn_preparation = new_batch is None
 
         if new_batch is not None:
@@ -1499,7 +1516,7 @@ class Scheduler(
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+            ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
 
@@ -1916,8 +1933,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
-        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-        return local_batch, any(is_extend_in_batch)
+        return local_batch
 
     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
@@ -2104,6 +2120,21 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
+
+        if not _is_cpu:
+            ret["memory_usage"]["cuda_graph"] = round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            )
+
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
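
For orientation, the memory_usage block added above ends up with roughly the shape sketched below. This is a hand-written illustration, not sglang output; the numbers are invented, and the units are simply whatever model_runner reports (assumed here to be GB).

# Illustrative only: approximate shape of ret["memory_usage"] built by get_internal_state.
# Keys come from the hunk above; values are made up, and "cuda_graph" is omitted on CPU.
example_memory_usage = {
    "weight": 13.42,           # weight_load_mem_usage, rounded to 2 decimals
    "kvcache": 40.11,          # KV-cache pool size reported by the allocator
    "token_capacity": 520000,  # int(self.max_total_num_tokens)
    "cuda_graph": 1.87,        # present only when _is_cpu is False
}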
@@ -2192,7 +2223,7 @@ class Scheduler(
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 to_del.append(i)
 
         # Sort in reverse order to avoid index issues when deleting
@@ -2209,7 +2240,7 @@ class Scheduler(
             # Abort method 2: call `set_finish_with_abort`
             # The request will still run one prefill forward pass.
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
                 if req.grammar:
                     req.grammar.cancel()
@@ -2222,7 +2253,9 @@ class Scheduler(
             reqs = self.running_batch.reqs + self.cur_batch.reqs
 
         for req in reqs:
-            if req.rid.startswith(recv_req.rid) and not req.finished():
+            if not req.finished() and (
+                recv_req.abort_all or req.rid.startswith(recv_req.rid)
+            ):
                 # Abort method 3: set `to_abort=True`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
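
The three abort hunks above all apply the same predicate; a minimal standalone sketch of that matching rule, with hypothetical names rather than sglang code:

# Sketch of the abort filter after this change: abort_all short-circuits the rid prefix check.
def should_abort(rid: str, target_rid: str, abort_all: bool) -> bool:
    return abort_all or rid.startswith(target_rid)

assert should_abort("req-123-0", "req-123", abort_all=False)      # prefix match
assert should_abort("req-999-7", "req-123", abort_all=True)       # abort_all overrides
assert not should_abort("req-999-7", "req-123", abort_all=False)  # unrelated request survives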
@@ -2242,6 +2275,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
 
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """In-place loading a new lora adapter from disk or huggingface."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the lora adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
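
The two new handlers plug into the (request type, handler) dispatcher registration shown earlier in this diff. Below is a self-contained sketch of that pattern using simplified stand-in classes; TinyDispatcher and the request class are illustrative only, not sglang APIs.

# Toy version of the type-based dispatch the scheduler uses for LoRA load/unload requests.
from dataclasses import dataclass


@dataclass
class LoadLoRAAdapterReq:
    lora_name: str
    lora_path: str


class TinyDispatcher:
    def __init__(self, mapping):
        # mapping: list of (request type, handler) pairs, as in the hunk above
        self._handlers = dict(mapping)

    def __call__(self, recv_req):
        return self._handlers[type(recv_req)](recv_req)


def load_lora_adapter(req: LoadLoRAAdapterReq) -> str:
    # The real handler forwards to tp_worker and flushes the KV/radix cache on success.
    return f"loaded {req.lora_name} from {req.lora_path}"


dispatcher = TinyDispatcher([(LoadLoRAAdapterReq, load_lora_adapter)])
print(dispatcher(LoadLoRAAdapterReq("my-adapter", "/tmp/my-adapter")))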
@@ -2254,8 +2317,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2279,9 +2343,8 @@ class Scheduler(
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
         tags = recv_req.tags
-        import subprocess
 
-        if tags is None:
+        if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
@@ -2292,17 +2355,20 @@ class Scheduler(
             self.stashed_model_static_state = _export_static_state(
                 self.tp_worker.worker.model_runner.model
             )
+            torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
         tags = recv_req.tags
+
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
                 self.tp_worker.worker.model_runner.model,
                 self.stashed_model_static_state,

sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                 stream_interval = (
                     req.sampling_params.stream_interval or self.stream_interval
                 )
-                should_output = len(req.output_ids) % stream_interval == 0
+                should_output = (
+                    len(req.output_ids) % stream_interval == 1
+                    if not self.model_config.is_multimodal_gen
+                    and stream_interval > 1
+                    else len(req.output_ids) % stream_interval == 0
+                )
             else:
                 should_output = (
                     len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                    and not self.model_config.is_multimodal_gen
+                    if not self.model_config.is_multimodal_gen
+                    else False
                 )
 
             if should_output:
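
A quick way to see the effect of the first change above: with stream_interval > 1 on a text model, triggering on % stream_interval == 1 instead of == 0 streams the first generated token immediately rather than after a full interval. Minimal sketch with simplified names, not sglang code:

# Compare when the first streamed chunk fires under the old and new modulo checks.
stream_interval = 4

def fires_old(num_output_ids: int) -> bool:
    return num_output_ids % stream_interval == 0

def fires_new(num_output_ids: int) -> bool:
    return num_output_ids % stream_interval == 1

print(next(n for n in range(1, 10) if fires_old(n)))  # 4: first chunk after a full interval
print(next(n for n in range(1, 10) if fires_new(n)))  # 1: first token streamed right away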

sglang/srt/managers/session_controller.py
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
             for child in self.childs[1:]:
-                prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
                 ret += child._str_helper(prefix)
             return ret
 
@@ -106,14 +106,22 @@ class Session:
                 last_req.origin_input_ids
                 + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+
+            if session_params.drop_previous_output:
+                input_ids = last_req.origin_input_ids[:]
+
             if session_params.offset and session_params.offset != 0:
                 input_ids = input_ids[: session_params.offset] + req.input_ids
             else:
                 input_ids += req.input_ids
+
             input_ids_unpadded = (
                 last_req.origin_input_ids_unpadded
                 + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.drop_previous_output:
+                input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
+
             if session_params.offset and session_params.offset != 0:
                 input_ids_unpadded = (
                     input_ids_unpadded[: session_params.offset] + req.input_ids
@@ -138,10 +146,11 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.multimodal_inputs = last_req.mm_inputs
+            new_req.multimodal_inputs = last_req.multimodal_inputs
         new_req.tokenizer = tokenizer
+
         if abort:
-            new_req.to_abort = True
+            new_req.set_finish_with_abort("Invalid request session id")
         else:
             new_req_node = SessionReqNode(new_req, last_req_node)
             self.req_nodes[req.rid] = new_req_node