sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -23,7 +23,6 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
-from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -36,6 +35,7 @@ from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
+    configure_gc_logger,
     configure_logger,
     disable_request_logging,
     get_available_gpu_memory,
@@ -148,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -450,8 +453,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
-
-        # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
@@ -508,6 +509,9 @@ class Scheduler(
         )
         self.init_disaggregation()

+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
     def maybe_sleep_on_idle(self):
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()
@@ -559,12 +563,20 @@ class Scheduler(
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    tp_cache_group=self.tp_cpu_group,
+                    tp_cache_group=(
+                        self.attn_tp_cpu_group
+                        if self.server_args.enable_dp_attention
+                        else self.tp_cpu_group
+                    ),
                     page_size=self.page_size,
                     hicache_ratio=server_args.hicache_ratio,
                     hicache_size=server_args.hicache_size,
                     hicache_write_policy=server_args.hicache_write_policy,
                 )
+                self.tp_worker.register_hicache_layer_transfer_counter(
+                    self.tree_cache.cache_controller.layer_done_counter
+                )
+
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
@@ -622,7 +634,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +686,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -795,11 +817,28 @@ class Scheduler(
                             result.next_token_ids,
                             result.bid,
                         )
-                        pp_outputs = PPProxyTensors(
-                            {
-                                "next_token_ids": next_token_ids,
-                            }
-                        )
+                        if self.cur_batch.return_logprob:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                    "extend_input_len_per_req": result.extend_input_len_per_req,
+                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                                }
+                                | (
+                                    {
+                                        f"logits_output.{k}": v
+                                        for k, v in result.logits_output.__dict__.items()
+                                    }
+                                    if result.logits_output is not None
+                                    else {}
+                                )
+                            )
+                        else:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                }
+                            )
                         # send the output from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
@@ -816,12 +855,25 @@ class Scheduler(
                         )
                     )
                     mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    logits_output_args = {
+                        k[len("logits_output.") :]: v
+                        for k, v in next_pp_outputs.tensors.items()
+                        if k.startswith("logits_output.")
+                    }
+                    if len(logits_output_args) > 0:
+                        logits_output = LogitsProcessorOutput(**logits_output_args)
+                    else:
+                        logits_output = None
                     output_result = GenerationBatchResult(
-                        logits_output=None,
+                        logits_output=logits_output,
                         pp_hidden_states_proxy_tensors=None,
                         next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=None,
-                        extend_logprob_start_len_per_req=None,
+                        extend_input_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_input_len_per_req", None
+                        ),
+                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                            "extend_logprob_start_len_per_req", None
+                        ),
                         bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
@@ -1322,7 +1374,14 @@ class Scheduler(
            )
            raise ValueError(msg)

-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
            msg = (
                "req_to_token_pool memory leak detected!"
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
@@ -1383,6 +1442,15 @@ class Scheduler(
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
+
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
@@ -1395,8 +1463,8 @@ class Scheduler(
                ret = None

        # Handle DP attention
-        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
-            ret, _ = self.prepare_dp_attn_batch(ret)
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_mlp_sync_batch(ret)

        return ret

@@ -1428,15 +1496,14 @@ class Scheduler(
            return None

        if self.enable_hierarchical_cache:
-            # check for completion of hierarchical cache activities to release memory
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()

        # Get priority queue
-        prefix_computed = self.policy.calc_priority(self.waiting_queue)
+        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
+            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
@@ -1478,14 +1545,8 @@ class Scheduler(
                self.running_batch.batch_is_full = True
                break

-            req.init_next_round_input(
-                None if prefix_computed else self.tree_cache,
-                self.enable_hierarchical_cache,
-            )
-
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
@@ -1512,9 +1573,6 @@ class Scheduler(
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req
@@ -1538,6 +1596,12 @@ class Scheduler(
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
+
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
@@ -1613,6 +1677,11 @@ class Scheduler(
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1641,13 +1710,15 @@ class Scheduler(
            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
-            if batch.return_logprob:
+            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+            else:
+                extend_input_len_per_req = None
+            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
-                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
@@ -1695,12 +1766,11 @@ class Scheduler(
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        return self.prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1709,14 +1779,14 @@ class Scheduler(
            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
            enable_deepep_moe=self.server_args.enable_deepep_moe,
            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
        )

    @staticmethod
-    def prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
@@ -1725,6 +1795,7 @@ class Scheduler(
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
    ):
        # Check if other DP workers have running batches
        if local_batch is None:
@@ -1732,8 +1803,6 @@ class Scheduler(
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
@@ -1752,11 +1821,6 @@ class Scheduler(
        else:
            can_cuda_graph = 0

-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
@@ -1801,7 +1865,7 @@ class Scheduler(

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
-            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
@@ -1809,6 +1873,7 @@ class Scheduler(
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

@@ -1816,6 +1881,7 @@ class Scheduler(
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
@@ -2176,23 +2242,40 @@ class Scheduler(
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(
-            caller_name="release_memory_occupation"
-        )
-        self.stashed_model_static_state = _export_static_state(
-            self.tp_worker.worker.model_runner.model
-        )
-        self.memory_saver_adapter.pause()
-        self.flush_cache()
+        tags = recv_req.tags
+        import subprocess
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
-        self.memory_saver_adapter.resume()
-        _import_static_state(
-            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
-        )
-        del self.stashed_model_static_state
+        tags = recv_req.tags
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
        return ResumeMemoryOccupationReqOutput()

    def slow_down(self, recv_req: SlowDownReqInput):
@@ -2421,8 +2504,10 @@ class Scheduler(
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
            else:
-                raise RuntimeError("unsupported profile stage")
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
        else:
            # Check profiler
            if (
sglang/srt/managers/template_manager.py (new file)
@@ -0,0 +1,226 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Centralized template management for chat templates and completion templates.
+
+This module provides a unified interface for managing both chat conversation templates
+and code completion templates, eliminating global state and improving modularity.
+"""
+
+import json
+import logging
+import os
+from typing import Optional
+
+from sglang.srt.code_completion_parser import (
+    CompletionTemplate,
+    FimPosition,
+    completion_template_exists,
+    register_completion_template,
+)
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    get_conv_template_by_model_path,
+    register_conv_template,
+)
+from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateManager:
+    """
+    Centralized manager for chat and completion templates.
+
+    This class encapsulates all template-related state and operations,
+    eliminating the need for global variables and providing a clean
+    interface for template management.
+    """
+
+    def __init__(self):
+        self._chat_template_name: Optional[str] = None
+        self._completion_template_name: Optional[str] = None
+        self._jinja_template_content_format: Optional[str] = None
+
+    @property
+    def chat_template_name(self) -> Optional[str]:
+        """Get the current chat template name."""
+        return self._chat_template_name
+
+    @property
+    def completion_template_name(self) -> Optional[str]:
+        """Get the current completion template name."""
+        return self._completion_template_name
+
+    @property
+    def jinja_template_content_format(self) -> Optional[str]:
+        """Get the detected template content format ('string' or 'openai' or None)."""
+        return self._jinja_template_content_format
+
+    def load_chat_template(
+        self, tokenizer_manager, chat_template_arg: str, model_path: str
+    ) -> None:
+        """
+        Load a chat template from various sources.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            chat_template_arg: Template name or file path
+            model_path: Path to the model
+        """
+        logger.info(f"Loading chat template: {chat_template_arg}")
+
+        if not chat_template_exists(chat_template_arg):
+            if not os.path.exists(chat_template_arg):
+                raise RuntimeError(
+                    f"Chat template {chat_template_arg} is not a built-in template name "
+                    "or a valid chat template file path."
+                )
+
+            if chat_template_arg.endswith(".jinja"):
+                self._load_jinja_template(tokenizer_manager, chat_template_arg)
+            else:
+                self._load_json_chat_template(chat_template_arg)
+        else:
+            self._chat_template_name = chat_template_arg
+
+    def guess_chat_template_from_model_path(self, model_path: str) -> None:
+        """
+        Infer chat template name from model path.
+
+        Args:
+            model_path: Path to the model
+        """
+        template_name = get_conv_template_by_model_path(model_path)
+        if template_name is not None:
+            logger.info(f"Inferred chat template from model path: {template_name}")
+            self._chat_template_name = template_name
+
+    def load_completion_template(self, completion_template_arg: str) -> None:
+        """
+        Load completion template for code completion.
+
+        Args:
+            completion_template_arg: Template name or file path
+        """
+        logger.info(f"Loading completion template: {completion_template_arg}")
+
+        if not completion_template_exists(completion_template_arg):
+            if not os.path.exists(completion_template_arg):
+                raise RuntimeError(
+                    f"Completion template {completion_template_arg} is not a built-in template name "
+                    "or a valid completion template file path."
+                )
+
+            self._load_json_completion_template(completion_template_arg)
+        else:
+            self._completion_template_name = completion_template_arg
+
+    def initialize_templates(
+        self,
+        tokenizer_manager,
+        model_path: str,
+        chat_template: Optional[str] = None,
+        completion_template: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize all templates based on provided configuration.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            model_path: Path to the model
+            chat_template: Optional chat template name/path
+            completion_template: Optional completion template name/path
+        """
+        # Load chat template
+        if chat_template:
+            self.load_chat_template(tokenizer_manager, chat_template, model_path)
+        else:
+            self.guess_chat_template_from_model_path(model_path)
+
+        # Load completion template
+        if completion_template:
+            self.load_completion_template(completion_template)
+
+    def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+        """Load a Jinja template file."""
+        with open(template_path, "r") as f:
+            chat_template = "".join(f.readlines()).strip("\n")
+        tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+        self._chat_template_name = None
+        # Detect content format from the loaded template
+        self._jinja_template_content_format = detect_jinja_template_content_format(
+            chat_template
+        )
+        logger.info(
+            f"Detected chat template content format: {self._jinja_template_content_format}"
+        )
+
+    def _load_json_chat_template(self, template_path: str) -> None:
+        """Load a JSON chat template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of chat template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+        self._chat_template_name = template["name"]
+
+    def _load_json_completion_template(self, template_path: str) -> None:
+        """Load a JSON completion template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of completion template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                fim_position = FimPosition[template["fim_position"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown fim position: {template['fim_position']}"
+                ) from None
+
+            register_completion_template(
+                CompletionTemplate(
+                    name=template["name"],
+                    fim_begin_token=template["fim_begin_token"],
+                    fim_middle_token=template["fim_middle_token"],
+                    fim_end_token=template["fim_end_token"],
+                    fim_position=fim_position,
+                ),
+                override=True,
+            )
+        self._completion_template_name = template["name"]
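Note: the new file above defines the TemplateManager API that the 0.4.8 entrypoints use for chat/completion templates. The following is a minimal usage sketch based only on the code shown above, not part of the package diff; the model path is a placeholder, and tokenizer_manager may be None here because the inferred-template code path never touches it.

    from sglang.srt.managers.template_manager import TemplateManager

    template_manager = TemplateManager()
    # With no explicit chat_template argument, the manager only tries to infer a
    # built-in template name from the model path (guess_chat_template_from_model_path),
    # so tokenizer_manager is unused on this path.
    template_manager.initialize_templates(
        tokenizer_manager=None,
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    )
    print(template_manager.chat_template_name)  # a built-in template name, or None if nothing matched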