sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0

sglang/srt/managers/scheduler.py

@@ -58,6 +58,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -65,9 +66,6 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -82,6 +80,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +99,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -126,7 +128,8 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -146,6 +149,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
+    is_cpu,
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
@@ -164,6 +168,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
+_is_cpu = is_cpu()
+
 
 @dataclass
 class GenerationBatchResult:
@@ -415,14 +421,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
+        self.num_retracted_reqs: int = 0
+        self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None
 
-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -470,26 +478,12 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
+
+        # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
@@ -518,9 +512,12 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )
 
+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -570,7 +567,11 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            self.tree_cache = ChunkCache(
+            if self.model_config.is_hybrid:
+                ChunkCacheClass = SWAChunkCache
+            else:
+                ChunkCacheClass = ChunkCache
+            self.tree_cache = ChunkCacheClass(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
@@ -615,6 +616,21 @@ class Scheduler(
             )
         )
 
+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -694,9 +710,6 @@ class Scheduler(
                 transfer_backend=self.transfer_backend,
             )
 
-            # Metric for pre-allocation
-            self.num_tokens_pre_allocated = 0
-
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -915,7 +928,7 @@ class Scheduler(
             point_to_point_pyobj(
                 recv_reqs,
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 self.pp_rank * self.tp_size + dp_offset,
                 (self.pp_rank + 1) * self.tp_size + dp_offset,
             )
@@ -962,7 +975,7 @@ class Scheduler(
             recv_reqs = point_to_point_pyobj(
                 [],
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 (self.pp_rank - 1) * self.tp_size + dp_offset,
                 self.pp_rank * self.tp_size + dp_offset,
             )
@@ -1087,7 +1100,7 @@ class Scheduler(
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
         ):
-            req.finished_reason = FINISH_ABORT(
+            req.set_finish_with_abort(
                 f"Invalid request: session id {recv_req.session_params.id} does not exist"
             )
             self._add_request_to_queue(req)
@@ -1283,9 +1296,8 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         num_new_seq = len(can_run_list)
@@ -1294,7 +1306,7 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"{usage_msg}"
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
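
The prefill-log hunks above (and the decode-log hunks that follow) replace the inline token-accounting arithmetic with a single allocator-provided log_usage() call, which lets SWA-style allocators report usage for the right pool. A minimal sketch of the contract implied by the removed code; the argument names are standalone stand-ins, not sglang's actual allocator state:

# Sketch of the log_usage() contract, reconstructed from the removed inline
# code above. The real method reads allocator state and takes only the
# evictable size; this stand-alone version makes all inputs explicit.
def log_usage(available_size: int, evictable_size: int, max_total_num_tokens: int):
    num_used = max_total_num_tokens - (available_size + evictable_size)
    usage_msg = f"token usage: {num_used / max_total_num_tokens:.2f}, "
    return usage_msg, num_used

msg, used = log_usage(available_size=700, evictable_size=100, max_total_num_tokens=1000)
assert (msg, used) == ("token usage: 0.20, ", 200)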
@@ -1337,9 +1349,8 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         if RECORD_STEP_TIME:
@@ -1347,12 +1358,7 @@ class Scheduler(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = (
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-        )
+        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1366,7 +1372,7 @@ class Scheduler(
             msg += f"accept len: {spec_accept_length:.2f}, "
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
@@ -1390,10 +1396,11 @@ class Scheduler(
         self._publish_kv_events()
 
     def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-        )
+        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
+        else:
+            available_token_size = self.token_to_kv_pool_allocator.available_size()
+        available_size = available_token_size + self.tree_cache.evictable_size()
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
             self.max_total_num_tokens
@@ -1404,7 +1411,7 @@ class Scheduler(
             msg = (
                 "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{available_token_size=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
             raise ValueError(msg)
@@ -1483,7 +1490,7 @@ class Scheduler(
         if need_dp_attn_preparation and not self.spec_algorithm.is_none():
             # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
             # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            new_batch = self.prepare_mlp_sync_batch(new_batch)
             need_dp_attn_preparation = new_batch is None
 
         if new_batch is not None:
@@ -1499,7 +1506,7 @@ class Scheduler(
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+            ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
 
@@ -1916,8 +1923,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
-        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-        return local_batch, any(is_extend_in_batch)
+        return local_batch
 
     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
@@ -2104,6 +2110,21 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
+
+        if not _is_cpu:
+            ret["memory_usage"]["cuda_graph"] = round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            )
+
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
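
get_internal_state() now reports a memory-usage breakdown: weights, KV cache, token capacity, and (on non-CPU builds) CUDA graphs. A hedged client sketch; the /get_internal_state route, HTTP method, and GB units are assumptions carried over from other sglang releases, since this diff shows only the scheduler side:

# Hypothetical client sketch; route name, HTTP method, and units (GB, rounded
# to two decimals) are assumptions, not confirmed by this diff.
import requests

state = requests.get("http://localhost:30000/get_internal_state", timeout=10).json()
mem = state["memory_usage"]
print(f"weights: {mem['weight']} GB, kv cache: {mem['kvcache']} GB")
print(f"token capacity: {mem['token_capacity']}")
if "cuda_graph" in mem:  # present only on non-CPU builds, per the hunk above
    print(f"cuda graphs: {mem['cuda_graph']} GB")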
@@ -2192,7 +2213,7 @@ class Scheduler(
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 to_del.append(i)
 
         # Sort in reverse order to avoid index issues when deleting
@@ -2209,7 +2230,7 @@ class Scheduler(
             # Abort method 2: call `set_finish_with_abort`
             # The request will still run one prefill forward pass.
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
                 if req.grammar:
                     req.grammar.cancel()
@@ -2222,7 +2243,9 @@ class Scheduler(
             reqs = self.running_batch.reqs + self.cur_batch.reqs
 
         for req in reqs:
-            if req.rid.startswith(recv_req.rid) and not req.finished():
+            if not req.finished() and (
+                recv_req.abort_all or req.rid.startswith(recv_req.rid)
+            ):
                 # Abort method 3: set `to_abort=True`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
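
All three abort paths now honor a new abort_all flag alongside the existing rid-prefix match. A self-contained restatement of the matching rule; the AbortReq field layout below is an assumed sketch, since the diff only shows that recv_req carries .rid and .abort_all:

from dataclasses import dataclass

@dataclass
class AbortReq:              # assumed shape, for illustration only
    rid: str = ""            # request-id prefix to match
    abort_all: bool = False  # new in 0.4.9: match every in-flight request

def should_abort(req_rid: str, recv_req: AbortReq, finished: bool = False) -> bool:
    # Mirrors the running-batch hunk: finished requests are never re-aborted.
    return not finished and (recv_req.abort_all or req_rid.startswith(recv_req.rid))

assert should_abort("chatcmpl-42-0", AbortReq(rid="chatcmpl-42"))
assert should_abort("anything", AbortReq(abort_all=True))
assert not should_abort("anything", AbortReq(abort_all=True), finished=True)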
@@ -2242,6 +2265,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
 
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """In-place loading a new lora adapter from disk or huggingface."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the lora adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
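
Both new handlers flush the cache after a successful adapter change, since cached KV prefixes may have been computed under a different adapter set. A hypothetical end-to-end sketch; the Engine-level method names, argument names, and startup flags are assumptions (EngineBase.py gains 8 lines in this release, which plausibly declares these entry points, but this diff does not show them):

# Hypothetical usage sketch -- all names below are assumptions, not confirmed
# by this diff; consult the 0.4.9 docs for the real entry points.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    lora_paths=["base-adapter=/adapters/base"],  # assumed: LoRA enabled at startup
)
llm.load_lora_adapter(lora_name="style-a", lora_path="/adapters/style-a")
out = llm.generate("Hello!", lora_path="style-a")  # assumed routing argument
llm.unload_lora_adapter(lora_name="style-a")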
@@ -2254,8 +2307,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
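
The new flush_cache field on the request lets a caller performing a multi-stage distributed weight update skip the intermediate cache flushes and pay the cost once, on the final stage. A runnable sketch of that call pattern with a stub engine; the real client-side signature is not shown in this diff:

# Sketch of the intended multi-stage pattern; FakeEngine stands in for the
# real client API, whose signature this diff does not show.
def update_in_stages(engine, stages):
    for i, stage in enumerate(stages):
        engine.update_weights_from_distributed(
            stage, flush_cache=(i == len(stages) - 1)  # flush only after the last stage
        )

class FakeEngine:
    def update_weights_from_distributed(self, stage, flush_cache):
        print(f"stage={stage!r} flush_cache={flush_cache}")

update_in_stages(FakeEngine(), ["embeddings", "layers", "lm_head"])
# stage='embeddings' flush_cache=False
# stage='layers' flush_cache=False
# stage='lm_head' flush_cache=True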

sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                 stream_interval = (
                     req.sampling_params.stream_interval or self.stream_interval
                 )
-                should_output = len(req.output_ids) % stream_interval == 0
+                should_output = (
+                    len(req.output_ids) % stream_interval == 1
+                    if not self.model_config.is_multimodal_gen
+                    and stream_interval > 1
+                    else len(req.output_ids) % stream_interval == 0
+                )
             else:
                 should_output = (
                     len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                    and not self.model_config.is_multimodal_gen
+                    if not self.model_config.is_multimodal_gen
+                    else False
                 )
 
             if should_output:
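
For non-multimodal models with stream_interval > 1, triggering on == 1 instead of == 0 flushes the first streamed chunk after the very first decoded token rather than after a full interval, improving time-to-first-token for streaming clients without changing the steady-state cadence. A standalone illustration:

# Standalone illustration (plain Python, not sglang code).
stream_interval = 4

def old_rule(n):  # 0.4.8.post1
    return n % stream_interval == 0

def new_rule(n):  # 0.4.9, non-multimodal, stream_interval > 1
    return n % stream_interval == 1

print([n for n in range(1, 13) if old_rule(n)])  # [4, 8, 12] -> first flush at token 4
print([n for n in range(1, 13) if new_rule(n)])  # [1, 5, 9]  -> first flush at token 1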

sglang/srt/managers/session_controller.py

@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
+            prefix = " " * len(origin_prefix) + " \- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret
 
@@ -106,14 +106,22 @@ class Session:
                 last_req.origin_input_ids
                 + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+
+            if session_params.drop_previous_output:
+                input_ids = last_req.origin_input_ids[:]
+
             if session_params.offset and session_params.offset != 0:
                 input_ids = input_ids[: session_params.offset] + req.input_ids
             else:
                 input_ids += req.input_ids
+
             input_ids_unpadded = (
                 last_req.origin_input_ids_unpadded
                 + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.drop_previous_output:
+                input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
+
             if session_params.offset and session_params.offset != 0:
                 input_ids_unpadded = (
                     input_ids_unpadded[: session_params.offset] + req.input_ids
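
The new drop_previous_output session parameter continues a session from the original prompt while discarding whatever the previous request generated; the same rule is applied to both the padded and unpadded token lists. A condensed, runnable restatement of the hunk above:

# Condensed restatement (plain Python; names shortened from the hunk above).
def next_input_ids(origin, prev_output, new_ids, drop_previous_output=False, offset=0):
    ids = origin + prev_output
    if drop_previous_output:
        ids = origin[:]  # keep only the original prompt
    if offset:
        return ids[:offset] + new_ids
    return ids + new_ids

assert next_input_ids([1, 2], [3, 4], [5]) == [1, 2, 3, 4, 5]
assert next_input_ids([1, 2], [3, 4], [5], drop_previous_output=True) == [1, 2, 5]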
@@ -138,10 +146,11 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.multimodal_inputs = last_req.mm_inputs
+            new_req.multimodal_inputs = last_req.multimodal_inputs
         new_req.tokenizer = tokenizer
+
         if abort:
-            new_req.to_abort = True
+            new_req.set_finish_with_abort("Invalid request session id")
         else:
             new_req_node = SessionReqNode(new_req, last_req_node)
             self.req_nodes[req.rid] = new_req_node