sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
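Note on the relocated OpenAI-compatible layer: items 22–30 and 120 above show that sglang/srt/openai_api/ is replaced by sglang/srt/entrypoints/openai/ (the old adapter.py is removed and split into serving_*.py modules). A minimal before/after import sketch, assuming the model names inside protocol.py (e.g. ChatCompletionRequest) are unchanged by the move:

    # sglang <= 0.4.7.post1 (old module path)
    # from sglang.srt.openai_api.protocol import ChatCompletionRequest

    # sglang >= 0.4.8.post1 (new module path per item 22 of this diff)
    from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest

The inline diff that follows covers the scheduler changes (item 81).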
sglang/srt/managers/scheduler.py
@@ -23,7 +23,6 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
-from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -36,6 +35,7 @@ from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
+    configure_gc_logger,
     configure_logger,
     disable_request_logging,
     get_available_gpu_memory,
@@ -148,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -179,6 +182,18 @@ class EmbeddingBatchResult:
     bid: int
 
 
+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
 class IdleSleeper:
     """
     In setups which have long inactivity periods it is desirable to reduce
@@ -220,6 +235,7 @@ class Scheduler(
         self.server_args = server_args
         self.tp_rank = tp_rank
         self.pp_rank = pp_rank
+        self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -258,6 +274,9 @@ class Scheduler(
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
+            self.send_metrics_from_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.metrics_ipc_name, False
+            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
@@ -283,6 +302,7 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
+            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -450,8 +470,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
-
-        # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
@@ -508,6 +526,9 @@ class Scheduler(
         )
         self.init_disaggregation()
 
+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
     def maybe_sleep_on_idle(self):
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()
@@ -559,12 +580,20 @@ class Scheduler(
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    tp_cache_group=self.tp_cpu_group,
+                    tp_cache_group=(
+                        self.attn_tp_cpu_group
+                        if self.server_args.enable_dp_attention
+                        else self.tp_cpu_group
+                    ),
                     page_size=self.page_size,
                     hicache_ratio=server_args.hicache_ratio,
                     hicache_size=server_args.hicache_size,
                     hicache_write_policy=server_args.hicache_write_policy,
                 )
+                self.tp_worker.register_hicache_layer_transfer_counter(
+                    self.tree_cache.cache_controller.layer_done_counter
+                )
+
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
@@ -622,7 +651,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +703,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )
 
             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -795,11 +834,28 @@ class Scheduler(
                             result.next_token_ids,
                             result.bid,
                         )
-                        pp_outputs = PPProxyTensors(
-                            {
-                                "next_token_ids": next_token_ids,
-                            }
-                        )
+                        if self.cur_batch.return_logprob:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                    "extend_input_len_per_req": result.extend_input_len_per_req,
+                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                                }
+                                | (
+                                    {
+                                        f"logits_output.{k}": v
+                                        for k, v in result.logits_output.__dict__.items()
+                                    }
+                                    if result.logits_output is not None
+                                    else {}
+                                )
+                            )
+                        else:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                }
+                            )
                         # send the output from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
@@ -816,12 +872,25 @@ class Scheduler(
                         )
                     )
                     mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    logits_output_args = {
+                        k[len("logits_output.") :]: v
+                        for k, v in next_pp_outputs.tensors.items()
+                        if k.startswith("logits_output.")
+                    }
+                    if len(logits_output_args) > 0:
+                        logits_output = LogitsProcessorOutput(**logits_output_args)
+                    else:
+                        logits_output = None
                     output_result = GenerationBatchResult(
-                        logits_output=None,
+                        logits_output=logits_output,
                         pp_hidden_states_proxy_tensors=None,
                         next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=None,
-                        extend_logprob_start_len_per_req=None,
+                        extend_input_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_input_len_per_req", None
+                        ),
+                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_logprob_start_len_per_req", None
+                        ),
                         bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
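The two hunks above make the pipeline-parallel last rank forward logprob-related outputs to the next stage: the fields of LogitsProcessorOutput are flattened into the PPProxyTensors dict under "logits_output."-prefixed keys, and the receiving side strips the prefix to rebuild the object (or None when nothing was sent). A minimal sketch of that flatten/unflatten convention, using a stand-in dataclass rather than the real LogitsProcessorOutput:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class StandInLogitsOutput:  # stand-in for LogitsProcessorOutput, illustration only
        next_token_logits: Optional[torch.Tensor] = None
        next_token_logprobs: Optional[torch.Tensor] = None


    def flatten(payload: dict, logits_output) -> dict:
        # Sender side: prefix every field so it travels inside one proxy-tensor dict.
        if logits_output is not None:
            payload |= {f"logits_output.{k}": v for k, v in logits_output.__dict__.items()}
        return payload


    def unflatten(tensors: dict):
        # Receiver side: strip the prefix and rebuild the object, or None if absent.
        kwargs = {
            k[len("logits_output.") :]: v
            for k, v in tensors.items()
            if k.startswith("logits_output.")
        }
        return StandInLogitsOutput(**kwargs) if kwargs else None


    sent = flatten({"next_token_ids": torch.tensor([7, 8])},
                   StandInLogitsOutput(next_token_logits=torch.zeros(2, 4)))
    received = unflatten(sent)
    assert received.next_token_logits.shape == (2, 4)

In the real hunks these prefixed entries ride in the same PPProxyTensors that already carries next_token_ids, extend_input_len_per_req and extend_logprob_start_len_per_req.
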
@@ -1187,6 +1256,22 @@ class Scheduler(
         req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
     def log_prefill_stats(
         self,
         adder: PrefillAdder,
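Together with the new send_metrics_from_scheduler PUSH socket (see the @@ -258 hunk earlier), the scheduler now pushes a pickled KvMetrics snapshot each time prefill or decode stats are logged. A rough sketch of a PULL-side consumer, assuming a hypothetical IPC endpoint string and assuming the trailing False passed to get_zmq_socket(...) means the scheduler connects rather than binds:

    import zmq

    METRICS_ENDPOINT = "ipc:///tmp/sglang_metrics"  # hypothetical endpoint name

    context = zmq.Context()
    socket = context.socket(zmq.PULL)
    socket.bind(METRICS_ENDPOINT)  # counterpart of the scheduler's PUSH socket

    while True:
        m = socket.recv_pyobj()  # an unpickled KvMetrics instance
        print(
            f"dp={m.data_parallel_rank} "
            f"slots={m.request_active_slots}/{m.request_total_slots} "
            f"waiting={m.num_requests_waiting} "
            f"kv_blocks={m.kv_active_blocks}/{m.kv_total_blocks} "
            f"cache_usage={m.gpu_cache_usage_perc:.1%}"
        )

Which process owns the receiving end is deployment-specific; this loop only illustrates the wire format implied by send_pyobj().
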
@@ -1239,6 +1324,7 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
             self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
         self._publish_kv_events()
 
     def log_decode_stats(
@@ -1300,6 +1386,7 @@ class Scheduler(
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
         self._publish_kv_events()
 
     def check_memory(self):
@@ -1322,7 +1409,14 @@ class Scheduler(
             )
             raise ValueError(msg)
 
-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
             msg = (
                 "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
@@ -1383,6 +1477,15 @@ class Scheduler(
                     self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
+
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
         if new_batch is not None:
             # Run prefill first if possible
             ret = new_batch
@@ -1395,8 +1498,8 @@ class Scheduler(
             ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
-            ret, _ = self.prepare_dp_attn_batch(ret)
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_mlp_sync_batch(ret)
 
         return ret
 
@@ -1428,15 +1531,14 @@ class Scheduler(
             return None
 
         if self.enable_hierarchical_cache:
-            # check for completion of hierarchical cache activities to release memory
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()
 
         # Get priority queue
-        prefix_computed = self.policy.calc_priority(self.waiting_queue)
+        self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
         adder = PrefillAdder(
+            self.page_size,
             self.tree_cache,
             self.token_to_kv_pool_allocator,
             self.running_batch,
@@ -1478,14 +1580,8 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            req.init_next_round_input(
-                None if prefix_computed else self.tree_cache,
-                self.enable_hierarchical_cache,
-            )
-
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
 
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1512,9 +1608,6 @@ class Scheduler(
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
 
-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1538,6 +1631,12 @@ class Scheduler(
             self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
+
         new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
@@ -1613,6 +1712,11 @@ class Scheduler(
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                 if self.pp_group.is_last_rank:
                     logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1641,13 +1745,15 @@ class Scheduler(
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
             # we can use the correct values in output processing.
-            if batch.return_logprob:
+            if batch.return_logprob or self.spec_algorithm.is_eagle():
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+            else:
+                extend_input_len_per_req = None
+            if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
                 ]
             else:
-                extend_input_len_per_req = None
                 extend_logprob_start_len_per_req = None
 
             ret = GenerationBatchResult(
@@ -1695,12 +1801,11 @@ class Scheduler(
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        return self.prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1709,14 +1814,14 @@ class Scheduler(
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
         )
 
     @staticmethod
-    def prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1725,6 +1830,7 @@ class Scheduler(
         enable_two_batch_overlap: bool,
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1732,8 +1838,6 @@ class Scheduler(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1752,11 +1856,6 @@ class Scheduler(
         else:
             can_cuda_graph = 0
 
-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
@@ -1801,7 +1900,7 @@ class Scheduler(
 
         if local_batch is not None:
             # TODO: handle the case when moe_dense_tp_size != 1
-            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+            if not require_mlp_tp_gather:
                 local_batch.global_num_tokens = [num_tokens]
                 local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
             else:
@@ -1809,6 +1908,7 @@ class Scheduler(
                 local_batch.global_num_tokens_for_logprob = (
                     global_num_tokens_for_logprob
                 )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode
 
@@ -1816,6 +1916,7 @@ class Scheduler(
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
         return local_batch, any(is_extend_in_batch)
 
     def get_idle_batch(self):
@@ -2135,8 +2236,8 @@ class Scheduler(
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
@@ -2153,8 +2254,8 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2165,10 +2266,11 @@ class Scheduler(
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
             if recv_req.flush_cache:
-                flash_cache_success = self.flush_cache()
-                assert flash_cache_success, "Cache flush failed after updating weights"
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
+        barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
@@ -2176,23 +2278,40 @@ class Scheduler(
         return GetWeightsByNameReqOutput(parameter)
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(
-            caller_name="release_memory_occupation"
-        )
-        self.stashed_model_static_state = _export_static_state(
-            self.tp_worker.worker.model_runner.model
-        )
-        self.memory_saver_adapter.pause()
-        self.flush_cache()
+        tags = recv_req.tags
+        import subprocess
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
-        self.memory_saver_adapter.resume()
-        _import_static_state(
-            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
-        )
-        del self.stashed_model_static_state
+        tags = recv_req.tags
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
         return ResumeMemoryOccupationReqOutput()
 
     def slow_down(self, recv_req: SlowDownReqInput):
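The tag-aware release/resume pair above frees memory in a fixed order: the KV cache is flushed and paused before weights are released, and weights are restored before the KV cache is re-enabled. A compact sketch of just that ordering, with a stand-in adapter and assumed values for the two constants (the real ones live in sglang/srt/constants.py); the static-state stash/restore done by the actual handlers is omitted here:

    GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"  # assumed literal values, illustration only
    GPU_MEMORY_TYPE_WEIGHTS = "weights"

    ALL_TAGS = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]


    def release(adapter, flush_cache, tags=None):
        # Release order: KV cache first (after flushing it), then weights.
        if tags is None:
            tags = ALL_TAGS
        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            flush_cache()
        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)


    def resume(adapter, tags=None):
        # Resume order mirrors release in reverse: weights first, then KV cache.
        if tags is None:
            tags = ALL_TAGS
        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
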
@@ -2421,8 +2540,10 @@ class Scheduler(
                 if self.profiler_decode_ct > self.profiler_target_decode_ct:
                     if self.profile_in_progress:
                         self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
             else:
-                raise RuntimeError("unsupported profile stage")
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
         else:
             # Check profiler
             if (