sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -25,7 +25,6 @@ if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

logger = logging.getLogger(__name__)

@@ -124,7 +123,7 @@ class TransferBuffer:
    """

    def __init__(
-        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
+        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
@@ -232,35 +231,62 @@ class HiCacheController:
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy
        self.page_size = page_size
-        # using kernel for small page KV cache transfer and DMA for large pages
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend

        self.enable_storage = False
        # todo: move backend initialization to storage backend module
        if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

            if storage_backend == "file":
                self.storage_backend = HiCacheFile()
-                self.enable_storage = True
-                # todo: threshold policy for prefetching
-                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+                self.get_hash_str = get_hash_str
+            elif storage_backend == "nixl":
+                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+                self.storage_backend = HiCacheNixl()
+                self.get_hash_str = get_hash_str
+            elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
+                self.storage_backend = MooncakeStore()
+                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            elif storage_backend == "hf3fs":
+                from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
+
+                rank = get_tensor_model_parallel_rank()
+                bytes_per_page = (
+                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                )
+                dtype = mem_pool_host.dtype
+                self.storage_backend = HiCacheHF3FS.from_env_config(
+                    rank, bytes_per_page, dtype
+                )
+                self.get_hash_str = get_hash_str
            else:
                raise NotImplementedError(
                    f"Unsupported storage backend: {storage_backend}"
                )
+            self.enable_storage = True
+            # todo: threshold policy for prefetching
+            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
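
The hf3fs branch sizes its storage pages directly from the host pool: bytes_per_page = get_size_per_token() * page_size. A rough back-of-the-envelope sketch of that product, using purely hypothetical model dimensions (the real value comes from mem_pool_host.get_size_per_token()):

# Hypothetical dimensions, only to illustrate the bytes_per_page calculation above.
num_layers, kv_heads, head_dim = 32, 8, 128
dtype_bytes = 2   # bf16
page_size = 64    # tokens per page

# K and V for every layer, per token
bytes_per_token = 2 * num_layers * kv_heads * head_dim * dtype_bytes
bytes_per_page = bytes_per_token * page_size
print(bytes_per_token, bytes_per_page)  # 131072 (128 KiB per token), 8388608 (8 MiB per page)
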
@@ -412,11 +438,8 @@ class HiCacheController:
            host_indices, device_indices = self.move_indices(
                operation.host_indices, operation.device_indices
            )
-            self.mem_pool_device.backup_to_host_all_layer(
-                self.mem_pool_host,
-                host_indices,
-                device_indices,
-                self.io_backend,
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
            )
            self.write_stream.synchronize()
            self.mem_pool_host.complete_io(operation.host_indices)
@@ -456,8 +479,8 @@ class HiCacheController:
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
-                self.mem_pool_device.load_from_host_per_layer(
-                    self.mem_pool_host,
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
@@ -515,6 +538,41 @@ class HiCacheController:
            operation.mark_done()
            return operation.completed_tokens, operation.hash_value

+    def generic_page_transfer(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            # todo: zero copy
+            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+                page_hashes
+            )
+            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[completed_tokens],
+                        page_data[i],
+                    )
+                    completed_tokens += self.page_size
+            else:
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
+                break
+
+    def mooncake_page_transfer(self, operation):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            operation.hash_value, operation.host_indices
+        )
+        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
+        operation.increment(len(operation.hash_value) * self.page_size)
+
    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
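
generic_page_transfer above walks operation.hash_value in fixed-size batches and writes each fetched page at the next page_size-aligned host offset, while mooncake_page_transfer instead hands the store raw buffer pointers from get_buffer_meta so it can write into the pinned host pool directly (the generic path still stages through dummy pages, hence the "todo: zero copy"). A standalone sketch of just the batching and offset bookkeeping, with plain lists standing in for the storage backend and host pool:

# Illustration only: lists stand in for the storage backend and the host memory pool.
page_size = 4
batch_size = 2
hash_values = ["h0", "h1", "h2", "h3", "h4"]

completed_tokens = 0
for i in range(0, len(hash_values), batch_size):
    page_hashes = hash_values[i : i + batch_size]
    # batch_get(page_hashes, ...) would return one flat data page per hash here
    for h in page_hashes:
        print(f"page {h} -> host_indices[{completed_tokens}]")
        completed_tokens += page_size
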
@@ -522,24 +580,10 @@
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
-                for h in operation.hash_value:
-                    page_data = self.storage_backend.get(h)
-                    if page_data is None:
-                        logger.warning(
-                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
-                        )
-                        break
-                    if operation.increment(self.page_size):
-                        self.mem_pool_host.set_from_flat_data_page(
-                            operation.host_indices[operation.completed_tokens],
-                            page_data,
-                        )
-                    else:
-                        # operation terminated by controller, release pre-allocated memory
-                        self.mem_pool_host.free(
-                            operation.host_indices[operation.completed_tokens :]
-                        )
-                        break
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_transfer(operation)
+                else:
+                    self.generic_page_transfer(operation)
            except Empty:
                continue

@@ -563,18 +607,27 @@
            remaining_tokens = len(tokens_to_fetch)
            hash_value = []
            while remaining_tokens >= self.page_size:
-                last_hash = get_hash_str(
+                last_hash = self.get_hash_str(
                    tokens_to_fetch[
                        storage_hit_count : storage_hit_count + self.page_size
                    ],
                    last_hash,
                )
-                if self.storage_backend.exists(last_hash):
-                    storage_hit_count += self.page_size
-                    hash_value.append(last_hash)
-                    remaining_tokens -= self.page_size
-                else:
-                    break
+
+                # todo, more unified interface
+                if not isinstance(self.storage_backend, MooncakeStore):
+                    if not self.storage_backend.exists(last_hash):
+                        break
+                hash_value.append(last_hash)
+                storage_hit_count += self.page_size
+                remaining_tokens -= self.page_size
+
+            if isinstance(self.storage_backend, MooncakeStore):
+                # deferring to batch exists for mooncake store
+                exist_result = self.storage_backend.exists(hash_value)
+                storage_hit_count = (
+                    sum(1 for v in exist_result.values() if v != 0) * self.page_size
+                )

            if self.tp_world_size > 1:
                storage_hit_count_tensor = torch.tensor(
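
Prefetch matching walks the requested tokens one page at a time and feeds the previous page's hash into the next call, so each page key commits to its entire prefix; non-mooncake backends stop at the first missing page, while the mooncake store defers to one batched exists() call. A minimal sketch of such a chained page hash, using hashlib rather than sglang's actual get_hash_str:

# Illustrative chained page hashing; sglang's get_hash_str may differ in detail.
import hashlib
from typing import Iterable, Optional

def chained_page_hash(tokens: Iterable[int], prior_hash: Optional[str] = None) -> str:
    h = hashlib.sha256()
    if prior_hash:
        h.update(bytes.fromhex(prior_hash))
    for t in tokens:
        h.update(t.to_bytes(4, "little"))
    return h.hexdigest()

page_size = 4
tokens = list(range(10))   # only the two full pages get hashed
last_hash = None
for start in range(0, len(tokens) - page_size + 1, page_size):
    last_hash = chained_page_hash(tokens[start : start + page_size], last_hash)
    print(start, last_hash[:16])
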
@@ -583,7 +636,7 @@
                torch.distributed.all_reduce(
                    storage_hit_count_tensor,
                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_group,
+                    group=self.prefetch_tp_group,
                )
                storage_hit_count = storage_hit_count_tensor.item()

@@ -622,6 +675,47 @@
        self.backup_queue.put(operation)
        return operation.id

+    def generic_page_backup(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            page_data = [
+                self.mem_pool_host.get_flat_data_page(
+                    operation.host_indices[j * self.page_size]
+                )
+                for j in range(i, i + len(page_hashes))
+            ]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
+    def mooncake_page_backup(self, operation):
+        if len(operation.hash_value):
+            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
+            indices = operation.host_indices.tolist()
+            non_exist_keys = []
+            non_exist_indices = []
+            for i in range(len(operation.hash_value)):
+                if not exist_hashvalues[operation.hash_value[i]]:
+                    non_exist_keys.append(operation.hash_value[i])
+                    non_exist_indices.extend(
+                        indices[i * self.page_size : (i + 1) * self.page_size]
+                    )
+            if len(non_exist_keys) > 0:
+                key_strs, buffer_ptrs, buffer_sizes = (
+                    self.mem_pool_host.get_buffer_meta(
+                        non_exist_keys, non_exist_indices
+                    )
+                )
+                # TODO: check the return value of batch set to see how many tokens are set successfully
+                self.storage_backend.batch_set(
+                    key_strs,
+                    target_location=buffer_ptrs,
+                    target_sizes=buffer_sizes,
+                )
+            operation.completed_tokens += len(operation.hash_value) * self.page_size
+
    def backup_thread_func(self):
        """
        Manage backup operations from host memory to storage backend.
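
Taken together, the prefetch and backup paths only touch a small backend surface: exists(), batch_get(), batch_set(), and, for mooncake, register_buffer(). A rough Protocol capturing that surface, inferred from the call sites in this diff rather than taken from any interface sglang declares:

# Inferred from the call sites above; not an interface defined by sglang itself.
from typing import Any, List, Optional, Protocol, Union

class HiCacheStorageLike(Protocol):
    def exists(self, keys: Union[str, List[str]]) -> Any:
        """Single-key membership test, or a dict of per-key results for mooncake."""
        ...

    def batch_get(self, keys: List[str], *args: Any) -> Optional[list]:
        """Fetch pages; generic backends fill destination pages, mooncake takes pointers/sizes."""
        ...

    def batch_set(self, keys: List[str], *args: Any, **kwargs: Any) -> Any:
        """Write pages; mooncake uses target_location/target_sizes keyword arguments."""
        ...

    def register_buffer(self, kv_buffer: Any) -> None:
        """Optional (mooncake): pre-register the pinned host KV buffer for direct transfers."""
        ...
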
@@ -635,21 +729,25 @@
                last_hash = operation.last_hash
                tokens_to_backup = operation.token_ids

-                for i in range(0, len(tokens_to_backup), self.page_size):
-                    last_hash = get_hash_str(
-                        tokens_to_backup[i : i + self.page_size], last_hash
-                    )
-                    success = self.storage_backend.set(
+                backup_hit_count = 0
+                remaining_tokens = len(tokens_to_backup)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_backup[
+                            backup_hit_count : backup_hit_count + self.page_size
+                        ],
                        last_hash,
-                        self.mem_pool_host.get_flat_data_page(
-                            operation.host_indices[i]
-                        ),
                    )
-                    if not success:
-                        logger.warning(f"Failed to write page {last_hash} to storage.")
-                        break
-                    operation.completed_tokens += self.page_size
-                    operation.hash_value.append(last_hash)
+                    backup_hit_count += self.page_size
+                    hash_value.append(last_hash)
+                    remaining_tokens -= self.page_size
+                operation.hash_value = hash_value
+
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_backup(operation)
+                else:
+                    self.generic_page_backup(operation)

                min_completed_tokens = operation.completed_tokens
                if self.tp_world_size > 1:
@@ -659,7 +757,7 @@
                    torch.distributed.all_reduce(
                        completed_tokens_tensor,
                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.tp_group,
+                        group=self.backup_tp_group,
                    )
                    min_completed_tokens = completed_tokens_tensor.item()

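
Both threads reconcile per-rank progress with a MIN all-reduce over their own dedicated gloo group (prefetch_tp_group for storage hits, backup_tp_group for completed tokens), so collectives from the two threads never interleave on one communicator and every TP worker acts on the most conservative count. A self-contained sketch of that pattern, assuming a torchrun launch with two or more processes:

# Run with: torchrun --nproc_per_node=2 min_allreduce_sketch.py  (hypothetical standalone script)
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
group_ranks = list(range(dist.get_world_size()))

# Two dedicated gloo groups over the same ranks, as in the diff above.
prefetch_tp_group = dist.new_group(group_ranks, backend="gloo")
backup_tp_group = dist.new_group(group_ranks, backend="gloo")

# Pretend per-rank progress; MIN yields the value every rank can safely agree on.
storage_hit_count = torch.tensor(64 * (rank + 1), dtype=torch.int64)
dist.all_reduce(storage_hit_count, op=dist.ReduceOp.MIN, group=prefetch_tp_group)

completed_tokens = torch.tensor(128 + 64 * rank, dtype=torch.int64)
dist.all_reduce(completed_tokens, op=dist.ReduceOp.MIN, group=backup_tp_group)

print(f"rank {rank}: hit={storage_hit_count.item()}, completed={completed_tokens.item()}")
dist.destroy_process_group()
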
sglang/srt/managers/data_parallel_controller.py

@@ -222,6 +222,7 @@ class DataParallelController:
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
+                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=run_scheduler_process,
                    args=(
@@ -229,6 +230,7 @@
                        rank_port_args,
                        gpu_id,
                        tp_rank,
+                        moe_ep_rank,
                        pp_rank,
                        dp_rank,
                        writer,
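
The two DataParallelController hunks above derive moe_ep_rank from the TP rank (consecutive blocks of tp_size // ep_size TP ranks share one expert-parallel rank) and thread it through to run_scheduler_process. A tiny worked example with hypothetical sizes:

# Hypothetical sizes, only to show the integer arithmetic used above.
tp_size, ep_size = 8, 2
for tp_rank in range(tp_size):
    moe_ep_rank = tp_rank // (tp_size // ep_size)
    print(tp_rank, moe_ep_rank)   # tp_rank 0-3 -> 0, tp_rank 4-7 -> 1
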
sglang/srt/managers/io_struct.py

@@ -152,8 +152,6 @@ class GenerateReqInput:
        else:
            self._normalize_batch_inputs()

-        self._validate_session_params()
-
    def _validate_inputs(self):
        """Validate that the input configuration is valid."""
        if (