sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py +153 -63

@@ -58,6 +58,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -65,9 +66,6 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -82,6 +80,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +99,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -126,7 +128,8 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -146,6 +149,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
+    is_cpu,
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
@@ -164,6 +168,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
+_is_cpu = is_cpu()
+
 
 @dataclass
 class GenerationBatchResult:
@@ -182,6 +188,18 @@ class EmbeddingBatchResult:
     bid: int
 
 
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
 class IdleSleeper:
     """
     In setups which have long inactivity periods it is desirable to reduce
@@ -223,6 +241,7 @@ class Scheduler(
         self.server_args = server_args
         self.tp_rank = tp_rank
         self.pp_rank = pp_rank
+        self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -261,6 +280,9 @@ class Scheduler(
         self.send_to_tokenizer = get_zmq_socket(
             context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
+        self.send_metrics_from_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.metrics_ipc_name, False
+        )
 
         if server_args.skip_tokenizer_init:
             # Directly send to the TokenizerManager
@@ -286,6 +308,7 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
+            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -398,14 +421,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
+        self.num_retracted_reqs: int = 0
+        self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None
 
-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -453,26 +478,12 @@ class Scheduler(
             t = threading.Thread(target=self.watchdog_thread, daemon=True)
             t.start()
             self.parent_process = psutil.Process().parent()
+
+        # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
 
@@ -501,9 +512,12 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )
 
+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -553,7 +567,11 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            self.tree_cache = ChunkCache(
+            if self.model_config.is_hybrid:
+                ChunkCacheClass = SWAChunkCache
+            else:
+                ChunkCacheClass = ChunkCache
+            self.tree_cache = ChunkCacheClass(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
@@ -598,6 +616,21 @@ class Scheduler(
             )
         )
 
+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -677,9 +710,6 @@ class Scheduler(
                 transfer_backend=self.transfer_backend,
             )
 
-            # Metric for pre-allocation
-            self.num_tokens_pre_allocated = 0
-
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -898,7 +928,7 @@ class Scheduler(
             point_to_point_pyobj(
                 recv_reqs,
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 self.pp_rank * self.tp_size + dp_offset,
                 (self.pp_rank + 1) * self.tp_size + dp_offset,
             )
@@ -945,7 +975,7 @@ class Scheduler(
             recv_reqs = point_to_point_pyobj(
                 [],
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 (self.pp_rank - 1) * self.tp_size + dp_offset,
                 self.pp_rank * self.tp_size + dp_offset,
             )
@@ -1070,7 +1100,7 @@ class Scheduler(
                 recv_req.session_params is not None
                 and recv_req.session_params.id is not None
             ):
-                req.finished_reason = FINISH_ABORT(
+                req.set_finish_with_abort(
                     f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
                 self._add_request_to_queue(req)
@@ -1239,6 +1269,22 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
     def log_prefill_stats(
         self,
         adder: PrefillAdder,
@@ -1250,9 +1296,8 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         num_new_seq = len(can_run_list)
@@ -1261,7 +1306,7 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"{usage_msg}"
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1291,6 +1336,7 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
         self.metrics_collector.log_stats(self.stats)
+        self._emit_kv_metrics()
         self._publish_kv_events()
 
     def log_decode_stats(
@@ -1303,9 +1349,8 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
+        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
+            self.tree_cache.evictable_size()
         )
 
         if RECORD_STEP_TIME:
@@ -1313,12 +1358,7 @@ class Scheduler(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = (
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-        )
+        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1332,7 +1372,7 @@ class Scheduler(
             msg += f"accept len: {spec_accept_length:.2f}, "
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
-            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
@@ -1352,13 +1392,15 @@ class Scheduler(
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.metrics_collector.log_stats(self.stats)
+        self._emit_kv_metrics()
         self._publish_kv_events()
 
     def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-        )
+        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
+        else:
+            available_token_size = self.token_to_kv_pool_allocator.available_size()
+        available_size = available_token_size + self.tree_cache.evictable_size()
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
             self.max_total_num_tokens
@@ -1369,7 +1411,7 @@ class Scheduler(
             msg = (
                 "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{available_token_size=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
             raise ValueError(msg)
@@ -1448,7 +1490,7 @@ class Scheduler(
         if need_dp_attn_preparation and not self.spec_algorithm.is_none():
             # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
             # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            new_batch = self.prepare_mlp_sync_batch(new_batch)
             need_dp_attn_preparation = new_batch is None
 
         if new_batch is not None:
@@ -1464,7 +1506,7 @@ class Scheduler(
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+            ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
 
@@ -1881,8 +1923,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
-        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-        return local_batch, any(is_extend_in_batch)
+        return local_batch
 
     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
@@ -2069,6 +2110,21 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
+
+        if not _is_cpu:
+            ret["memory_usage"]["cuda_graph"] = round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            )
+
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
@@ -2157,7 +2213,7 @@ class Scheduler(
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 to_del.append(i)
 
         # Sort in reverse order to avoid index issues when deleting
@@ -2174,7 +2230,7 @@ class Scheduler(
             # Abort method 2: call `set_finish_with_abort`
             # The request will still run one prefill forward pass.
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
                 if req.grammar:
                     req.grammar.cancel()
@@ -2187,7 +2243,9 @@ class Scheduler(
             reqs = self.running_batch.reqs + self.cur_batch.reqs
 
         for req in reqs:
-            if req.rid.startswith(recv_req.rid) and not req.finished():
+            if not req.finished() and (
+                recv_req.abort_all or req.rid.startswith(recv_req.rid)
+            ):
                 # Abort method 3: set `to_abort=True`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
@@ -2201,12 +2259,42 @@ class Scheduler(
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
 
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """In-place loading a new lora adapter from disk or huggingface."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading lora adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the lora adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
@@ -2219,8 +2307,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2231,10 +2320,11 @@ class Scheduler(
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
             if recv_req.flush_cache:
-                flash_cache_success = self.flush_cache()
-                assert flash_cache_success, "Cache flush failed after updating weights"
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
+        barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2

@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                 stream_interval = (
                     req.sampling_params.stream_interval or self.stream_interval
                 )
-                should_output = len(req.output_ids) % stream_interval == 0
+                should_output = (
+                    len(req.output_ids) % stream_interval == 1
+                    if not self.model_config.is_multimodal_gen
+                    and stream_interval > 1
+                    else len(req.output_ids) % stream_interval == 0
+                )
             else:
                 should_output = (
                     len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                    and not self.model_config.is_multimodal_gen
+                    if not self.model_config.is_multimodal_gen
+                    else False
                 )
 
             if should_output:
@@ -54,7 +54,7 @@ class SessionReqNode:
54
54
  prefix += " -- " + self.childs[0].req.rid
55
55
  ret = self.childs[0]._str_helper(prefix)
56
56
  for child in self.childs[1:]:
57
- prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
57
+ prefix = " " * len(origin_prefix) + " \- " + child.req.rid
58
58
  ret += child._str_helper(prefix)
59
59
  return ret
60
60
 
@@ -106,14 +106,22 @@ class Session:
106
106
  last_req.origin_input_ids
107
107
  + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
108
108
  )
109
+
110
+ if session_params.drop_previous_output:
111
+ input_ids = last_req.origin_input_ids[:]
112
+
109
113
  if session_params.offset and session_params.offset != 0:
110
114
  input_ids = input_ids[: session_params.offset] + req.input_ids
111
115
  else:
112
116
  input_ids += req.input_ids
117
+
113
118
  input_ids_unpadded = (
114
119
  last_req.origin_input_ids_unpadded
115
120
  + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
116
121
  )
122
+ if session_params.drop_previous_output:
123
+ input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
124
+
117
125
  if session_params.offset and session_params.offset != 0:
118
126
  input_ids_unpadded = (
119
127
  input_ids_unpadded[: session_params.offset] + req.input_ids
@@ -138,10 +146,11 @@ class Session:
138
146
  token_ids_logprob=req.token_ids_logprob,
139
147
  )
140
148
  if last_req is not None:
141
- new_req.multimodal_inputs = last_req.mm_inputs
149
+ new_req.multimodal_inputs = last_req.multimodal_inputs
142
150
  new_req.tokenizer = tokenizer
151
+
143
152
  if abort:
144
- new_req.to_abort = True
153
+ new_req.set_finish_with_abort("Invalid request session id")
145
154
  else:
146
155
  new_req_node = SessionReqNode(new_req, last_req_node)
147
156
  self.req_nodes[req.rid] = new_req_node