sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/cache_controller.py

@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-    from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-    from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-        MooncakeStore,
-        get_hash_str_mooncake,
-    )
-    from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 
 logger = logging.getLogger(__name__)
 
@@ -237,40 +231,36 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-        # using kernel for small page KV cache transfer and DMA for large pages
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend
 
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            self.storage_backend_type = storage_backend
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.get_hash_str = get_hash_str
+            elif storage_backend == "nixl":
+                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+                self.storage_backend = HiCacheNixl()
+                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
 
                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -288,6 +278,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -439,11 +439,8 @@ class HiCacheController:
            host_indices, device_indices = self.move_indices(
                operation.host_indices, operation.device_indices
            )
-           self.mem_pool_device.backup_to_host_all_layer(
-               self.mem_pool_host,
-               host_indices,
-               device_indices,
-               self.io_backend,
+           self.mem_pool_host.backup_from_device_all_layer(
+               self.mem_pool_device, host_indices, device_indices, self.io_backend
            )
            self.write_stream.synchronize()
            self.mem_pool_host.complete_io(operation.host_indices)
@@ -483,8 +480,8 @@ class HiCacheController:
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
-               self.mem_pool_device.load_from_host_per_layer(
-                   self.mem_pool_host,
+               self.mem_pool_host.load_to_device_per_layer(
+                   self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
@@ -545,7 +542,11 @@ class HiCacheController:
    def generic_page_transfer(self, operation, batch_size=8):
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
-           page_data = self.storage_backend.batch_get(page_hashes)
+           # todo: zero copy
+           dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+               page_hashes
+           )
+           page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
            if page_data is None:
                logger.warning(
                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
@@ -573,6 +574,9 @@ class HiCacheController:
        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
        operation.increment(len(operation.hash_value) * self.page_size)
 
+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
@@ -580,7 +584,7 @@ class HiCacheController:
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
-               if isinstance(self.storage_backend, MooncakeStore):
+               if self.is_mooncake_backend():
                    self.mooncake_page_transfer(operation)
                else:
                    self.generic_page_transfer(operation)
@@ -615,14 +619,14 @@ class HiCacheController:
                )
 
                # todo, more unified interface
-               if not isinstance(self.storage_backend, MooncakeStore):
+               if not self.is_mooncake_backend():
                    if not self.storage_backend.exists(last_hash):
                        break
                hash_value.append(last_hash)
                storage_hit_count += self.page_size
                remaining_tokens -= self.page_size
 
-           if isinstance(self.storage_backend, MooncakeStore):
+           if self.is_mooncake_backend():
                # deferring to batch exists for mooncake store
                exist_result = self.storage_backend.exists(hash_value)
                storage_hit_count = (
@@ -679,7 +683,7 @@ class HiCacheController:
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
            page_data = [
-               self.mem_pool_host.get_flat_data_pages(
+               self.mem_pool_host.get_flat_data_page(
                    operation.host_indices[j * self.page_size]
                )
                for j in range(i, i + len(page_hashes))
@@ -744,7 +748,7 @@ class HiCacheController:
                remaining_tokens -= self.page_size
            operation.hash_value = hash_value
 
-           if isinstance(self.storage_backend, MooncakeStore):
+           if self.is_mooncake_backend():
                self.mooncake_page_backup(operation)
            else:
                self.generic_page_backup(operation)

sglang/srt/managers/data_parallel_controller.py

@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List
 
 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()
 
     @classmethod
     def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
@@ -234,6 +250,7 @@ class DataParallelController:
                 pp_rank,
                 dp_rank,
                 writer,
+                self.balance_meta,
             ),
         )
         with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
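
The dispatch rule added above is an argmin over estimated per-worker load: tokens currently held by each worker plus the tokens of its in-flight (dispatched but not yet received) requests. A standalone sketch of that selection, using plain lists and dicts in place of the shared-memory state that DPBalanceMeta manages:

from typing import Dict, List


def pick_min_token_worker(
    local_tokens: List[int], onfly_info: List[Dict[int, int]]
) -> int:
    """Return the index of the worker with the fewest estimated tokens."""
    totals = [
        local + sum(onfly.values()) for local, onfly in zip(local_tokens, onfly_info)
    ]
    return totals.index(min(totals))


# Example: worker 1 has fewer local + in-flight tokens, so it receives the request.
assert pick_min_token_worker([100, 40], [{1: 30}, {2: 20, 3: 10}]) == 1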
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)
 
     try:
-        controller = DataParallelController(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()
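
DPBalanceMeta itself lives in sglang/srt/managers/utils.py (the +45 lines in the file list above) and its implementation is not shown in this diff. Based only on how it is used here (a mutex, get/set accessors for per-worker on-fly maps and local token counts, and an explicit destructor for the underlying mp.Manager), a rough, hypothetical stand-in could look like this; names follow the call sites, the internals are an assumption:

import multiprocessing as mp
from typing import Dict, List


class DPBalanceMetaSketch:
    """Hypothetical stand-in for DPBalanceMeta, reconstructed from its call sites."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self._manager = mp.Manager()
        self.mutex = self._manager.Lock()
        # tokens currently held by each worker
        self._local_tokens = self._manager.list([0] * num_workers)
        # per-worker {dp_balance_id: num_input_tokens} for dispatched-but-unreceived requests
        self._onfly = self._manager.list(
            [self._manager.dict() for _ in range(num_workers)]
        )

    def get_shared_local_tokens(self) -> List[int]:
        return list(self._local_tokens)

    def set_shared_local_tokens(self, tokens: List[int]) -> None:
        self._local_tokens[:] = tokens

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        return [dict(d) for d in self._onfly]

    def set_shared_onfly_info(self, onfly: List[Dict[int, int]]) -> None:
        for shared, new in zip(self._onfly, onfly):
            shared.clear()
            shared.update(new)

    def destructor(self) -> None:
        # The controller calls this in the finally block above to tear down the Manager process.
        self._manager.shutdown()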

sglang/srt/managers/io_struct.py

@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For dp balance
+    dp_balance_id: int = -1
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1
 
 
 @dataclass
@@ -1097,7 +1102,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult

sglang/srt/managers/schedule_batch.py

@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
-    "enable_deepep_moe",
+    "moe_a2a_backend",
     "deepep_mode",
-    "enable_ep_moe",
     "enable_flashinfer_cutlass_moe",
     "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
@@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
     "enable_multimodal",
+    "enable_symm_mem",
 ]
 
 # Put some global args for easy access
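
The "moe_a2a_backend" key replaces the old "enable_deepep_moe" boolean, and the new sglang/srt/layers/moe/utils.py (listed above, +43 lines) supplies the MoeA2ABackend and DeepEPMode helpers used later in scheduler.py. Their definitions are not part of this diff; a plausible minimal shape, inferred only from the call sites MoeA2ABackend(value).is_deepep() and DeepEPMode(value), might be the following sketch (member names are assumptions):

from enum import Enum


class MoeA2ABackendSketch(Enum):
    # Hypothetical values; the call sites only require construction from a
    # string and an is_deepep() predicate.
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        return self == MoeA2ABackendSketch.DEEPEP


class DeepEPModeSketch(Enum):
    # scheduler.py switches from DeepEPMode[name] (lookup by member name)
    # to DeepEPMode(value) (lookup by value), so values are plain strings.
    NORMAL = "normal"
    LOW_LATENCY = "low_latency"
    AUTO = "auto"


assert MoeA2ABackendSketch("deepep").is_deepep()
assert DeepEPModeSketch("auto") is DeepEPModeSketch.AUTO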

sglang/srt/managers/schedule_policy.py

@@ -455,7 +455,9 @@ class PrefillAdder:
         if not self.is_hybrid:
             # Skip this logic for swa. The SWA has different memory management, and
             # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                 # tokens_left gives a reservative calculation as the last token is not stored
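
The remaining-token estimate now charges only the tokens this extend step will add (extend_input_len), rounded up to whole KV-cache pages via ceil_paged_tokens. That helper is not shown in this hunk; the rounding it implies is simply the following sketch (the real method presumably reads page_size from self rather than taking it as an argument):

def ceil_paged_tokens(num_tokens: int, page_size: int) -> int:
    """Round a token count up to a multiple of the KV-cache page size (illustrative)."""
    if page_size <= 1:
        return num_tokens
    return (num_tokens + page_size - 1) // page_size * page_size


assert ceil_paged_tokens(10, 16) == 16
assert ceil_paged_tokens(33, 16) == 48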

sglang/srt/managers/scheduler.py

@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -125,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
-    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -203,6 +203,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -522,6 +523,15 @@
             ]
         )
 
+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -569,7 +579,23 @@
                 page_size=self.page_size,
             )
         else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
@@ -588,6 +614,7 @@
                        == "fa3"  # hot fix for incompatibility
                        else server_args.hicache_io_backend
                    ),
+                   hicache_mem_layout=server_args.hicache_mem_layout,
                    hicache_storage_backend=server_args.hicache_storage_backend,
                )
                self.tp_worker.register_hicache_layer_transfer_counter(
@@ -1032,6 +1059,12 @@
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1442,6 +1475,11 @@
 
         # Handle DP attention
         if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
@@ -1743,6 +1781,9 @@
         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)
 
+        self.maybe_send_health_check_signal()
+
+    def maybe_send_health_check_signal(self):
         if self.return_health_check_ct:
             # Return some signal for the health check.
             # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1761,12 +1802,94 @@
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=self.server_args.enable_deepep_moe,
-            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            enable_deepep_moe=MoeA2ABackend(
+                self.server_args.moe_a2a_backend
+            ).is_deepep(),
+            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
 
+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The first element is the length of the list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1. Check if the rid received by each worker this round is present in onfly.
+                # If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only first worker write info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
 
     @staticmethod
     def prepare_mlp_sync_batch_raw(
@@ -2343,11 +2466,19 @@ class IdleSleeper:
 
     def __init__(self, sockets):
         self.poller = zmq.Poller()
+        self.last_empty_time = time.time()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)
 
     def maybe_sleep(self):
         self.poller.poll(1000)
+        if (
+            global_config.torch_empty_cache_interval > 0
+            and time.time() - self.last_empty_time
+            > global_config.torch_empty_cache_interval
+        ):
+            self.last_empty_time = time.time()
+            torch.cuda.empty_cache()
 
 
 def is_health_check_generate_req(recv_req):
@@ -2367,6 +2498,7 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
     # Generate the prefix
     prefix = ""
@@ -2400,7 +2532,14 @@
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(
-            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {

sglang/srt/managers/template_manager.py

@@ -84,26 +84,27 @@ class TemplateManager:
         if chat_template_arg:
             self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
         else:
-            # Try HuggingFace template first
-            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-            if hf_template:
-                self._jinja_template_content_format = (
-                    detect_jinja_template_content_format(hf_template)
-                )
-                logger.info(
-                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                )
-                return
-
-            # Fallback to SGLang template guessing
+            # Guess chat template from model path
             self.guess_chat_template_from_model_path(model_path)
 
-            # Set default format if no template was found
+            # If no pre-defined template was found, fallback to HuggingFace template
             if self._chat_template_name is None:
-                self._jinja_template_content_format = "string"
-                logger.info(
-                    "No chat template found, defaulting to 'string' content format"
-                )
+                # Try HuggingFace template first
+                hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                if hf_template:
+                    # override the chat template
+                    tokenizer_manager.tokenizer.chat_template = hf_template
+                    self._jinja_template_content_format = (
+                        detect_jinja_template_content_format(hf_template)
+                    )
+                    logger.info(
+                        f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                    )
+                    return
+
+                # Default to string content format if no template was found
+                self._jinja_template_content_format = "string"
+                logger.info("No chat template found, defaulting to 'string' content format")
 
     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +258,15 @@
 
         Returns the chat template string if found, None otherwise.
         """
-        tokenizer = tokenizer_manager.tokenizer
-
-        # Try to get AutoTokenizer chat template
        try:
-            return tokenizer.get_chat_template()
+            if processor := tokenizer_manager.processor:
+                if hasattr(processor, "chat_template") and processor.chat_template:
+                    return processor.chat_template
+            if tokenizer := tokenizer_manager.tokenizer:
+                if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                    return tokenizer.chat_template
        except Exception as e:
-            logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+            logger.debug(f"Error getting chat template: {e}")
 
        logger.debug("No HuggingFace chat template found")
        return None
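
Net effect of the two template_manager.py hunks: an explicit --chat-template still wins, then SGLang's model-path guess, and only if neither matched does the manager fall back to the HuggingFace processor/tokenizer chat_template (now also writing it back onto the tokenizer), finally defaulting to the "string" content format. A condensed sketch of that priority order, with hypothetical stand-in inputs for the real lookups:

from typing import Optional


def resolve_chat_template(
    explicit_arg: Optional[str],
    guessed_template_name: Optional[str],
    processor_template: Optional[str],
    tokenizer_template: Optional[str],
) -> str:
    """Return which template source wins, mirroring TemplateManager's new order."""
    if explicit_arg:
        return "explicit"        # --chat-template from the user
    if guessed_template_name:
        return "sglang-builtin"  # guess_chat_template_from_model_path hit
    if processor_template or tokenizer_template:
        return "huggingface"     # processor.chat_template, then tokenizer.chat_template
    return "string-default"     # no template found anywhere


assert resolve_chat_template(None, None, None, "{{ messages }}") == "huggingface"
assert resolve_chat_template(None, "llama-3", None, None) == "sglang-builtin"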