sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84):
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,11 @@ if TYPE_CHECKING:
26
26
  from sglang.srt.mem_cache.memory_pool_host import HostKVCache
27
27
 
28
28
  from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
29
+ from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
30
+ MooncakeStore,
31
+ get_hash_str_mooncake,
32
+ )
33
+ from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
29
34
 
30
35
  logger = logging.getLogger(__name__)
31
36
 
@@ -124,7 +129,7 @@ class TransferBuffer:
124
129
  """
125
130
 
126
131
  def __init__(
127
- self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
132
+ self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
128
133
  ) -> None:
129
134
  self.stop_event = stop_event
130
135
  self.buffers = Queue(maxsize=buffer_count)
@@ -250,17 +255,39 @@ class HiCacheController:
250
255
  self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
251
256
  if self.tp_world_size > 1:
252
257
  group_ranks = torch.distributed.get_process_group_ranks(tp_group)
253
- self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
258
+ self.prefetch_tp_group = torch.distributed.new_group(
259
+ group_ranks, backend="gloo"
260
+ )
261
+ self.backup_tp_group = torch.distributed.new_group(
262
+ group_ranks, backend="gloo"
263
+ )
254
264
 
255
265
  if storage_backend == "file":
256
266
  self.storage_backend = HiCacheFile()
257
- self.enable_storage = True
258
- # todo: threshold policy for prefetching
259
- self.prefetch_threshold = max(prefetch_threshold, self.page_size)
267
+ self.get_hash_str = get_hash_str
268
+ elif storage_backend == "mooncake":
269
+ self.storage_backend = MooncakeStore()
270
+ self.get_hash_str = get_hash_str_mooncake
271
+ self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
272
+ elif storage_backend == "hf3fs":
273
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
274
+
275
+ rank = get_tensor_model_parallel_rank()
276
+ bytes_per_page = (
277
+ mem_pool_host.get_size_per_token() * mem_pool_host.page_size
278
+ )
279
+ dtype = mem_pool_host.dtype
280
+ self.storage_backend = HiCacheHF3FS.from_env_config(
281
+ rank, bytes_per_page, dtype
282
+ )
283
+ self.get_hash_str = get_hash_str
260
284
  else:
261
285
  raise NotImplementedError(
262
286
  f"Unsupported storage backend: {storage_backend}"
263
287
  )
288
+ self.enable_storage = True
289
+ # todo: threshold policy for prefetching
290
+ self.prefetch_threshold = max(prefetch_threshold, self.page_size)
264
291
 
265
292
  self.load_cache_event = load_cache_event
266
293
  self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -515,6 +542,37 @@ class HiCacheController:
515
542
  operation.mark_done()
516
543
  return operation.completed_tokens, operation.hash_value
517
544
 
545
+ def generic_page_transfer(self, operation, batch_size=8):
546
+ for i in range(0, len(operation.hash_value), batch_size):
547
+ page_hashes = operation.hash_value[i : i + batch_size]
548
+ page_data = self.storage_backend.batch_get(page_hashes)
549
+ if page_data is None:
550
+ logger.warning(
551
+ f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
552
+ )
553
+ break
554
+ completed_tokens = operation.completed_tokens
555
+ if operation.increment(self.page_size * len(page_hashes)):
556
+ for i in range(len(page_hashes)):
557
+ self.mem_pool_host.set_from_flat_data_page(
558
+ operation.host_indices[completed_tokens],
559
+ page_data[i],
560
+ )
561
+ completed_tokens += self.page_size
562
+ else:
563
+ # operation terminated by controller, release pre-allocated memory
564
+ self.mem_pool_host.free(
565
+ operation.host_indices[operation.completed_tokens :]
566
+ )
567
+ break
568
+
569
+ def mooncake_page_transfer(self, operation):
570
+ key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
571
+ operation.hash_value, operation.host_indices
572
+ )
573
+ self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
574
+ operation.increment(len(operation.hash_value) * self.page_size)
575
+
518
576
  def prefetch_io_aux_func(self):
519
577
  """
520
578
  Auxiliary function conducting IO operations for prefetching.
@@ -522,24 +580,10 @@ class HiCacheController:
522
580
  while not self.stop_event.is_set():
523
581
  try:
524
582
  operation = self.prefetch_buffer.get(block=True, timeout=1)
525
- for h in operation.hash_value:
526
- page_data = self.storage_backend.get(h)
527
- if page_data is None:
528
- logger.warning(
529
- f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
530
- )
531
- break
532
- if operation.increment(self.page_size):
533
- self.mem_pool_host.set_from_flat_data_page(
534
- operation.host_indices[operation.completed_tokens],
535
- page_data,
536
- )
537
- else:
538
- # operation terminated by controller, release pre-allocated memory
539
- self.mem_pool_host.free(
540
- operation.host_indices[operation.completed_tokens :]
541
- )
542
- break
583
+ if isinstance(self.storage_backend, MooncakeStore):
584
+ self.mooncake_page_transfer(operation)
585
+ else:
586
+ self.generic_page_transfer(operation)
543
587
  except Empty:
544
588
  continue
545
589
 
@@ -563,18 +607,27 @@ class HiCacheController:
563
607
  remaining_tokens = len(tokens_to_fetch)
564
608
  hash_value = []
565
609
  while remaining_tokens >= self.page_size:
566
- last_hash = get_hash_str(
610
+ last_hash = self.get_hash_str(
567
611
  tokens_to_fetch[
568
612
  storage_hit_count : storage_hit_count + self.page_size
569
613
  ],
570
614
  last_hash,
571
615
  )
572
- if self.storage_backend.exists(last_hash):
573
- storage_hit_count += self.page_size
574
- hash_value.append(last_hash)
575
- remaining_tokens -= self.page_size
576
- else:
577
- break
616
+
617
+ # todo, more unified interface
618
+ if not isinstance(self.storage_backend, MooncakeStore):
619
+ if not self.storage_backend.exists(last_hash):
620
+ break
621
+ hash_value.append(last_hash)
622
+ storage_hit_count += self.page_size
623
+ remaining_tokens -= self.page_size
624
+
625
+ if isinstance(self.storage_backend, MooncakeStore):
626
+ # deferring to batch exists for mooncake store
627
+ exist_result = self.storage_backend.exists(hash_value)
628
+ storage_hit_count = (
629
+ sum(1 for v in exist_result.values() if v != 0) * self.page_size
630
+ )
578
631
 
579
632
  if self.tp_world_size > 1:
580
633
  storage_hit_count_tensor = torch.tensor(
@@ -583,7 +636,7 @@ class HiCacheController:
583
636
  torch.distributed.all_reduce(
584
637
  storage_hit_count_tensor,
585
638
  op=torch.distributed.ReduceOp.MIN,
586
- group=self.tp_group,
639
+ group=self.prefetch_tp_group,
587
640
  )
588
641
  storage_hit_count = storage_hit_count_tensor.item()
589
642
 
@@ -622,6 +675,47 @@ class HiCacheController:
622
675
  self.backup_queue.put(operation)
623
676
  return operation.id
624
677
 
678
+ def generic_page_backup(self, operation, batch_size=8):
679
+ for i in range(0, len(operation.hash_value), batch_size):
680
+ page_hashes = operation.hash_value[i : i + batch_size]
681
+ page_data = [
682
+ self.mem_pool_host.get_flat_data_pages(
683
+ operation.host_indices[j * self.page_size]
684
+ )
685
+ for j in range(i, i + len(page_hashes))
686
+ ]
687
+ success = self.storage_backend.batch_set(page_hashes, page_data)
688
+ if not success:
689
+ logger.warning(f"Failed to write page {page_hashes} to storage.")
690
+ break
691
+ operation.completed_tokens += self.page_size * len(page_hashes)
692
+
693
+ def mooncake_page_backup(self, operation):
694
+ if len(operation.hash_value):
695
+ exist_hashvalues = self.storage_backend.exists(operation.hash_value)
696
+ indices = operation.host_indices.tolist()
697
+ non_exist_keys = []
698
+ non_exist_indices = []
699
+ for i in range(len(operation.hash_value)):
700
+ if not exist_hashvalues[operation.hash_value[i]]:
701
+ non_exist_keys.append(operation.hash_value[i])
702
+ non_exist_indices.extend(
703
+ indices[i * self.page_size : (i + 1) * self.page_size]
704
+ )
705
+ if len(non_exist_keys) > 0:
706
+ key_strs, buffer_ptrs, buffer_sizes = (
707
+ self.mem_pool_host.get_buffer_meta(
708
+ non_exist_keys, non_exist_indices
709
+ )
710
+ )
711
+ # TODO: check the return value of batch set to see how many tokens are set successfully
712
+ self.storage_backend.batch_set(
713
+ key_strs,
714
+ target_location=buffer_ptrs,
715
+ target_sizes=buffer_sizes,
716
+ )
717
+ operation.completed_tokens += len(operation.hash_value) * self.page_size
718
+
625
719
  def backup_thread_func(self):
626
720
  """
627
721
  Manage backup operations from host memory to storage backend.
@@ -635,21 +729,25 @@ class HiCacheController:
635
729
  last_hash = operation.last_hash
636
730
  tokens_to_backup = operation.token_ids
637
731
 
638
- for i in range(0, len(tokens_to_backup), self.page_size):
639
- last_hash = get_hash_str(
640
- tokens_to_backup[i : i + self.page_size], last_hash
641
- )
642
- success = self.storage_backend.set(
732
+ backup_hit_count = 0
733
+ remaining_tokens = len(tokens_to_backup)
734
+ hash_value = []
735
+ while remaining_tokens >= self.page_size:
736
+ last_hash = self.get_hash_str(
737
+ tokens_to_backup[
738
+ backup_hit_count : backup_hit_count + self.page_size
739
+ ],
643
740
  last_hash,
644
- self.mem_pool_host.get_flat_data_page(
645
- operation.host_indices[i]
646
- ),
647
741
  )
648
- if not success:
649
- logger.warning(f"Failed to write page {last_hash} to storage.")
650
- break
651
- operation.completed_tokens += self.page_size
652
- operation.hash_value.append(last_hash)
742
+ backup_hit_count += self.page_size
743
+ hash_value.append(last_hash)
744
+ remaining_tokens -= self.page_size
745
+ operation.hash_value = hash_value
746
+
747
+ if isinstance(self.storage_backend, MooncakeStore):
748
+ self.mooncake_page_backup(operation)
749
+ else:
750
+ self.generic_page_backup(operation)
653
751
 
654
752
  min_completed_tokens = operation.completed_tokens
655
753
  if self.tp_world_size > 1:
@@ -659,7 +757,7 @@ class HiCacheController:
659
757
  torch.distributed.all_reduce(
660
758
  completed_tokens_tensor,
661
759
  op=torch.distributed.ReduceOp.MIN,
662
- group=self.tp_group,
760
+ group=self.backup_tp_group,
663
761
  )
664
762
  min_completed_tokens = completed_tokens_tensor.item()
665
763
 
@@ -26,6 +26,7 @@ import zmq
26
26
 
27
27
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
28
28
  from sglang.srt.managers.io_struct import (
29
+ BlockReqInput,
29
30
  TokenizedEmbeddingReqInput,
30
31
  TokenizedGenerateReqInput,
31
32
  )
@@ -221,6 +222,7 @@ class DataParallelController:
221
222
  + ((pp_rank % pp_size_per_node) * tp_size_per_node)
222
223
  + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
223
224
  )
225
+ moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
224
226
  proc = mp.Process(
225
227
  target=run_scheduler_process,
226
228
  args=(
@@ -228,6 +230,7 @@ class DataParallelController:
228
230
  rank_port_args,
229
231
  gpu_id,
230
232
  tp_rank,
233
+ moe_ep_rank,
231
234
  pp_rank,
232
235
  dp_rank,
233
236
  writer,
@@ -282,6 +285,9 @@ class DataParallelController:
282
285
  ),
283
286
  ):
284
287
  self.dispatching(recv_req)
288
+ elif isinstance(recv_req, BlockReqInput):
289
+ for worker in self.workers:
290
+ worker.send_pyobj(recv_req)
285
291
  else:
286
292
  # Send other control messages to first worker of tp group
287
293
  for worker in self.workers[:: self.control_message_step]:
@@ -152,8 +152,6 @@ class GenerateReqInput:
152
152
  else:
153
153
  self._normalize_batch_inputs()
154
154
 
155
- self._validate_session_params()
156
-
157
155
  def _validate_inputs(self):
158
156
  """Validate that the input configuration is valid."""
159
157
  if (
@@ -911,6 +909,8 @@ class AbortReq:
911
909
  rid: str = ""
912
910
  # Whether to abort all requests
913
911
  abort_all: bool = False
912
+ # The finished reason data
913
+ finished_reason: Optional[Dict[str, Any]] = None
914
914
 
915
915
 
916
916
  @dataclass
@@ -1101,3 +1101,13 @@ class LoRAUpdateResult:
1101
1101
 
1102
1102
 
1103
1103
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
1104
+
1105
+
1106
+ class BlockReqType(Enum):
1107
+ BLOCK = 1
1108
+ UNBLOCK = 2
1109
+
1110
+
1111
+ @dataclass
1112
+ class BlockReqInput:
1113
+ type: BlockReqType