sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +3 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +10 -2
  11. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  12. sglang/srt/eplb/expert_distribution.py +5 -0
  13. sglang/srt/eplb/expert_location.py +17 -6
  14. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  15. sglang/srt/eplb/expert_location_updater.py +2 -0
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/step3_detector.py +436 -0
  18. sglang/srt/hf_transformers_utils.py +2 -0
  19. sglang/srt/jinja_template_utils.py +4 -1
  20. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +20 -640
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  24. sglang/srt/layers/quantization/fp8.py +0 -18
  25. sglang/srt/layers/quantization/unquant.py +0 -8
  26. sglang/srt/layers/quantization/w4afp8.py +1 -0
  27. sglang/srt/managers/cache_controller.py +143 -45
  28. sglang/srt/managers/data_parallel_controller.py +2 -0
  29. sglang/srt/managers/io_struct.py +0 -2
  30. sglang/srt/managers/scheduler.py +89 -671
  31. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  32. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  33. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  34. sglang/srt/managers/template_manager.py +62 -19
  35. sglang/srt/managers/tokenizer_manager.py +123 -74
  36. sglang/srt/managers/tp_worker.py +4 -0
  37. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  38. sglang/srt/mem_cache/hicache_storage.py +45 -11
  39. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  40. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  41. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  42. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  43. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  44. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  45. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  46. sglang/srt/model_executor/model_runner.py +5 -0
  47. sglang/srt/models/arcee.py +532 -0
  48. sglang/srt/models/deepseek_v2.py +2 -0
  49. sglang/srt/models/glm4_moe.py +3 -1
  50. sglang/srt/models/granitemoe.py +3 -0
  51. sglang/srt/models/grok.py +3 -0
  52. sglang/srt/models/hunyuan.py +1 -0
  53. sglang/srt/models/llama4.py +3 -0
  54. sglang/srt/models/mixtral.py +3 -0
  55. sglang/srt/models/olmoe.py +3 -0
  56. sglang/srt/models/phimoe.py +1 -0
  57. sglang/srt/models/step3_vl.py +994 -0
  58. sglang/srt/multimodal/processors/base_processor.py +15 -16
  59. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  60. sglang/srt/reasoning_parser.py +2 -1
  61. sglang/srt/server_args.py +10 -13
  62. sglang/srt/speculative/eagle_worker.py +2 -0
  63. sglang/utils.py +0 -11
  64. sglang/version.py +1 -1
  65. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
  66. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
  67. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  68. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens

-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )

-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
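In the hunk above, `expert_ids_ptr` can now contain -1 for experts that were filtered out under expert parallelism, so the kernel bails out early and zero-fills its output tile instead of touching weights it does not own; the token and column indices are also promoted to int64 to avoid 32-bit overflow on very large padded batches. A minimal PyTorch sketch of the zero-fill idea (hypothetical shapes and names, not the Triton kernel itself):

```python
import torch

# Minimal sketch: tokens routed to an expert that lives on another EP rank carry
# expert id -1 and contribute a zero row, so summing partial outputs across ranks
# (the later all-reduce) still yields the correct result.
def moe_partial_output(x, expert_ids, local_experts):
    out = torch.zeros(x.shape[0], local_experts[0].out_features)
    for token, expert in enumerate(expert_ids.tolist()):
        if expert == -1:          # expert owned by another EP rank: keep the zero row
            continue
        out[token] = local_experts[expert](x[token])  # compute only local experts
    return out

x = torch.randn(4, 8)
local_experts = [torch.nn.Linear(8, 8) for _ in range(2)]
expert_ids = torch.tensor([0, -1, 1, -1])   # two tokens belong to remote experts
print(moe_partial_output(x, expert_ids, local_experts).shape)  # torch.Size([4, 8])
```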
@@ -497,7 +516,6 @@ def fused_moe_kernel(

                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                # fix out of shared memory issue
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
                 else:
@@ -568,7 +586,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +596,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )

     # Threshold based on benchmark results
@@ -594,12 +608,11 @@ def moe_align_block_size(

     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         fuse_sorted_ids_padding,
     )
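The padding bound in `moe_align_block_size` grows because expert-parallel filtering introduces an extra pseudo-expert id (-1), so sorting has to reserve up to `block_size - 1` padding slots for `num_experts + 1` ids, and `cumsum_buffer` gains one more entry. A worked example with assumed sizes (not taken from a real run):

```python
# Worked example of the padded-size bound under assumed values.
num_tokens, top_k = 1024, 8
num_experts, block_size = 256, 128

numel = num_tokens * top_k                                # topk_ids.numel() == 8192
old_bound = numel + num_experts * (block_size - 1)        # 8192 + 256 * 127 = 40704
new_bound = numel + (num_experts + 1) * (block_size - 1)  # 8192 + 257 * 127 = 40831
# The extra (block_size - 1) slots cover the additional "-1" id used for filtered
# experts, and cumsum_buffer grows to num_experts + 2 entries accordingly.
print(old_bound, new_bound)
```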
@@ -7,11 +7,16 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -62,8 +67,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -77,21 +83,19 @@ class FusedMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
-        skip_quant: Optional[bool] = False,
     ):
         super().__init__()

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
-        self.expert_map = None
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.expert_map_cpu = None
+        self.expert_map_gpu = None

         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -99,28 +103,27 @@ class FusedMoE(torch.nn.Module):
             enable_ep_moe = False

         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.moe_tp_size = get_moe_tensor_parallel_world_size()
+        self.moe_tp_rank = get_moe_tensor_parallel_rank()
+        assert num_experts % self.moe_ep_size == 0
+        self.num_local_experts = num_experts // self.moe_ep_size
         if enable_ep_moe:
-            self.ep_size = self.tp_size
-            self.ep_rank = self.tp_rank
-            self.tp_size = 1
-            self.tp_rank = 0
+            # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
-            assert num_experts % self.ep_size == 0
-            self.num_local_experts = num_experts // self.ep_size
-            self.expert_map[
-                self.ep_rank
-                * self.num_local_experts : (self.ep_rank + 1)
+            self.expert_map_cpu[
+                self.moe_ep_rank
+                * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-        else:
-            self.ep_size = 1
-            self.ep_rank = 0
-            self.num_local_experts = num_experts
+            self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
         self.routed_scaling_factor = routed_scaling_factor
-        assert intermediate_size % self.tp_size == 0
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        assert intermediate_size % self.moe_tp_size == 0
+        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
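The rewritten constructor keeps a per-rank expert map: global expert ids owned by this EP rank map to local slots `0..num_local_experts-1`, everything else maps to -1, and a GPU copy of the table is used later in `forward()` to remap `topk_ids`. An illustrative reconstruction with assumed sizes (not library code):

```python
import torch

# Illustrative reconstruction of the expert-map layout for one EP rank.
num_experts, moe_ep_size, moe_ep_rank = 8, 2, 1
num_local_experts = num_experts // moe_ep_size        # 4 experts per EP rank

expert_map_cpu = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map_cpu[
    moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)
print(expert_map_cpu.tolist())            # [-1, -1, -1, -1, 0, 1, 2, 3] on rank 1

# forward() routes topk_ids through the GPU copy of this table, so experts owned
# by other ranks become -1 and hit the zero-fill path in the fused kernel.
topk_ids = torch.tensor([[0, 5], [6, 2]])
print(expert_map_cpu[topk_ids].tolist())  # [[-1, 1], [2, -1]]
```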
@@ -132,9 +135,6 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )

-        if skip_quant:
-            return
-
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -363,9 +363,9 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)

     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map is None:
+        if self.expert_map_cpu is None:
             return expert_id
-        return self.expert_map[expert_id].item()
+        return self.expert_map_cpu[expert_id].item()

     def weight_loader(
         self,
@@ -375,10 +375,48 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
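With expert-location (EPLB) metadata present, one logical expert can be materialized as several physical experts, so `weight_loader` now fans a checkpoint tensor out to every physical replica via `_weight_loader_physical`, while fused shared experts keep a one-to-one mapping. A schematic sketch of that fan-out, using a hypothetical dictionary in place of `logical_to_all_physical`:

```python
# Schematic sketch only; the real mapping comes from
# sglang.srt.eplb.expert_location.get_global_expert_location_metadata().
logical_to_physical = {0: [0, 7], 1: [1], 2: [2]}   # logical expert 0 is replicated

def load_expert_weight(logical_id, weight, load_one):
    # Write the same checkpoint tensor into every physical replica of the expert.
    for physical_id in logical_to_physical.get(logical_id, [logical_id]):
        load_one(physical_id, weight)

loaded = {}
load_expert_weight(0, "w_up_of_expert0", lambda pid, w: loaded.setdefault(pid, w))
print(loaded)   # {0: 'w_up_of_expert0', 7: 'w_up_of_expert0'}
```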
@@ -396,8 +434,7 @@ class FusedMoE(torch.nn.Module):
         expert_id: int,
     ) -> None:

-        # TP rank is set to 0 if EP is enabled
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+        tp_rank = self.moe_tp_rank

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -571,9 +608,14 @@ class FusedMoE(torch.nn.Module):
            )
            return

-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None

+        if self.expert_map_gpu is not None:
+            topk_output = topk_output._replace(
+                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -584,17 +626,17 @@ class FusedMoE(torch.nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             **(
                 dict(
-                    tp_rank=self.tp_rank,
-                    tp_size=self.tp_size,
-                    ep_rank=self.ep_rank,
-                    ep_size=self.ep_size,
+                    tp_rank=self.moe_tp_rank,
+                    tp_size=self.moe_tp_size,
+                    ep_rank=self.moe_ep_rank,
+                    ep_size=self.moe_ep_size,
                 )
                 if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
                 else {}
             ),
         )

-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states
@@ -627,3 +669,20 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
+
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
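The new classmethod simply enumerates `(param_name, weight_name, expert_id, shard_id)` tuples for per-expert input scales. Evaluating the same comprehension by hand for a two-expert layer (standalone mirror of the helper, not a call into sglang) gives:

```python
# Standalone mirror of the helper for num_experts=2: w1/w3 feed the fused "w13_"
# parameter, w2 feeds "w2_"; tuples are (param_name, weight_name, expert_id, shard_id).
mapping = [
    (
        "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
        f"experts.{expert_id}.{shard_id}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(2)
    for shard_id in ["w1", "w2", "w3"]
]
for row in mapping:
    print(row)
# ('experts.w13_', 'experts.0.w1.', 0, 'w1')
# ('experts.w2_', 'experts.0.w2.', 0, 'w2')
# ('experts.w13_', 'experts.0.w3.', 0, 'w3')  ... and likewise for expert 1
```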
@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, EPMoE):
-            return Fp8EPMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

-        if isinstance(layer, EPMoE):
-            layer.w13_weight_scale = (
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            )
-            layer.w2_weight_scale = (
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            )
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:

-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-        if isinstance(layer, EPMoE):
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         return self.forward(
             x=x,
             layer=layer,
@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
+        **kwargs,
     ) -> torch.Tensor:

         # TODO(ch-wan): move it out of this class
@@ -26,6 +26,11 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+    MooncakeStore,
+    get_hash_str_mooncake,
+)
+from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

 logger = logging.getLogger(__name__)

@@ -124,7 +129,7 @@ class TransferBuffer:
     """

     def __init__(
-        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
+        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
     ) -> None:
         self.stop_event = stop_event
         self.buffers = Queue(maxsize=buffer_count)
@@ -250,17 +255,39 @@ class HiCacheController:
         self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
         if self.tp_world_size > 1:
             group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-            self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+            self.prefetch_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )
+            self.backup_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )

         if storage_backend == "file":
             self.storage_backend = HiCacheFile()
-            self.enable_storage = True
-            # todo: threshold policy for prefetching
-            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.get_hash_str = get_hash_str
+        elif storage_backend == "mooncake":
+            self.storage_backend = MooncakeStore()
+            self.get_hash_str = get_hash_str_mooncake
+            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+        elif storage_backend == "hf3fs":
+            from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+            rank = get_tensor_model_parallel_rank()
+            bytes_per_page = (
+                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+            )
+            dtype = mem_pool_host.dtype
+            self.storage_backend = HiCacheHF3FS.from_env_config(
+                rank, bytes_per_page, dtype
+            )
+            self.get_hash_str = get_hash_str
         else:
             raise NotImplementedError(
                 f"Unsupported storage backend: {storage_backend}"
             )
+        self.enable_storage = True
+        # todo: threshold policy for prefetching
+        self.prefetch_threshold = max(prefetch_threshold, self.page_size)

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -515,6 +542,37 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

+    def generic_page_transfer(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[completed_tokens],
+                        page_data[i],
+                    )
+                    completed_tokens += self.page_size
+            else:
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
+                break
+
+    def mooncake_page_transfer(self, operation):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            operation.hash_value, operation.host_indices
+        )
+        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
+        operation.increment(len(operation.hash_value) * self.page_size)
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
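The new `generic_page_transfer` pulls prefetched pages from the storage backend in batches (eight hashes per `batch_get` call by default) instead of one `get` per page, while `mooncake_page_transfer` hands the whole operation to Mooncake's pointer-based batch API. A toy sketch of the batched-get pattern with a stand-in backend (hypothetical names, not the controller itself):

```python
# Toy sketch of batched prefetching with a dict-backed stand-in storage backend.
class DictBackend:
    def __init__(self, store):
        self.store = store

    def batch_get(self, keys):
        pages = [self.store.get(k) for k in keys]
        # mirror the controller's behavior: any miss fails the whole batch
        return None if any(p is None for p in pages) else pages

def fetch_pages(hashes, backend, batch_size=8):
    out = []
    for i in range(0, len(hashes), batch_size):
        batch = backend.batch_get(hashes[i : i + batch_size])
        if batch is None:   # stop at the first failed batch, like the prefetch loop
            break
        out.extend(batch)
    return out

backend = DictBackend({"h0": b"page0", "h1": b"page1"})
print(fetch_pages(["h0", "h1", "h2"], backend, batch_size=2))  # [b'page0', b'page1']
```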
@@ -522,24 +580,10 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                for h in operation.hash_value:
-                    page_data = self.storage_backend.get(h)
-                    if page_data is None:
-                        logger.warning(
-                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
-                        )
-                        break
-                    if operation.increment(self.page_size):
-                        self.mem_pool_host.set_from_flat_data_page(
-                            operation.host_indices[operation.completed_tokens],
-                            page_data,
-                        )
-                    else:
-                        # operation terminated by controller, release pre-allocated memory
-                        self.mem_pool_host.free(
-                            operation.host_indices[operation.completed_tokens :]
-                        )
-                        break
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_transfer(operation)
+                else:
+                    self.generic_page_transfer(operation)
             except Empty:
                 continue

@@ -563,18 +607,27 @@ class HiCacheController:
                remaining_tokens = len(tokens_to_fetch)
                hash_value = []
                while remaining_tokens >= self.page_size:
-                    last_hash = get_hash_str(
+                    last_hash = self.get_hash_str(
                        tokens_to_fetch[
                            storage_hit_count : storage_hit_count + self.page_size
                        ],
                        last_hash,
                    )
-                    if self.storage_backend.exists(last_hash):
-                        storage_hit_count += self.page_size
-                        hash_value.append(last_hash)
-                        remaining_tokens -= self.page_size
-                    else:
-                        break
+
+                    # todo, more unified interface
+                    if not isinstance(self.storage_backend, MooncakeStore):
+                        if not self.storage_backend.exists(last_hash):
+                            break
+                    hash_value.append(last_hash)
+                    storage_hit_count += self.page_size
+                    remaining_tokens -= self.page_size
+
+                if isinstance(self.storage_backend, MooncakeStore):
+                    # deferring to batch exists for mooncake store
+                    exist_result = self.storage_backend.exists(hash_value)
+                    storage_hit_count = (
+                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
+                    )

                if self.tp_world_size > 1:
                    storage_hit_count_tensor = torch.tensor(
@@ -583,7 +636,7 @@ class HiCacheController:
                    torch.distributed.all_reduce(
                        storage_hit_count_tensor,
                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.tp_group,
+                        group=self.prefetch_tp_group,
                    )
                    storage_hit_count = storage_hit_count_tensor.item()

@@ -622,6 +675,47 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id

+    def generic_page_backup(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            page_data = [
+                self.mem_pool_host.get_flat_data_pages(
+                    operation.host_indices[j * self.page_size]
+                )
+                for j in range(i, i + len(page_hashes))
+            ]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
+    def mooncake_page_backup(self, operation):
+        if len(operation.hash_value):
+            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
+            indices = operation.host_indices.tolist()
+            non_exist_keys = []
+            non_exist_indices = []
+            for i in range(len(operation.hash_value)):
+                if not exist_hashvalues[operation.hash_value[i]]:
+                    non_exist_keys.append(operation.hash_value[i])
+                    non_exist_indices.extend(
+                        indices[i * self.page_size : (i + 1) * self.page_size]
+                    )
+            if len(non_exist_keys) > 0:
+                key_strs, buffer_ptrs, buffer_sizes = (
+                    self.mem_pool_host.get_buffer_meta(
+                        non_exist_keys, non_exist_indices
+                    )
+                )
+                # TODO: check the return value of batch set to see how many tokens are set successfully
+                self.storage_backend.batch_set(
+                    key_strs,
+                    target_location=buffer_ptrs,
+                    target_sizes=buffer_sizes,
+                )
+        operation.completed_tokens += len(operation.hash_value) * self.page_size
+
     def backup_thread_func(self):
         """
         Manage backup operations from host memory to storage backend.
@@ -635,21 +729,25 @@ class HiCacheController:
                last_hash = operation.last_hash
                tokens_to_backup = operation.token_ids

-                for i in range(0, len(tokens_to_backup), self.page_size):
-                    last_hash = get_hash_str(
-                        tokens_to_backup[i : i + self.page_size], last_hash
-                    )
-                    success = self.storage_backend.set(
+                backup_hit_count = 0
+                remaining_tokens = len(tokens_to_backup)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_backup[
+                            backup_hit_count : backup_hit_count + self.page_size
+                        ],
                        last_hash,
-                        self.mem_pool_host.get_flat_data_page(
-                            operation.host_indices[i]
-                        ),
                    )
-                    if not success:
-                        logger.warning(f"Failed to write page {last_hash} to storage.")
-                        break
-                    operation.completed_tokens += self.page_size
-                    operation.hash_value.append(last_hash)
+                    backup_hit_count += self.page_size
+                    hash_value.append(last_hash)
+                    remaining_tokens -= self.page_size
+                operation.hash_value = hash_value
+
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_backup(operation)
+                else:
+                    self.generic_page_backup(operation)

                min_completed_tokens = operation.completed_tokens
                if self.tp_world_size > 1:
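Both the prefetch and backup paths now derive page keys through the same chained `self.get_hash_str` helper: each page's hash folds in the previous page's hash, so identical token prefixes always produce identical key sequences regardless of backend. A minimal sketch of the chaining with a stand-in hash function (not the real `get_hash_str` or `get_hash_str_mooncake`):

```python
import hashlib

# Stand-in chained page hashing: key(page_k) depends on key(page_{k-1}) and the
# page's tokens, so a shared prefix of tokens yields a shared prefix of keys.
def hash_page(tokens, prior_hash=None):
    h = hashlib.sha256()
    if prior_hash is not None:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, tokens)).encode())
    return h.hexdigest()

def page_hashes(token_ids, page_size):
    last_hash, hashes = None, []
    for i in range(0, len(token_ids) - len(token_ids) % page_size, page_size):
        last_hash = hash_page(token_ids[i : i + page_size], last_hash)
        hashes.append(last_hash)
    return hashes

# Two chained hashes; the trailing partial page (< page_size tokens) is skipped,
# matching the "while remaining_tokens >= self.page_size" loop above.
print(page_hashes(list(range(10)), page_size=4))
```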
@@ -659,7 +757,7 @@ class HiCacheController:
                    torch.distributed.all_reduce(
                        completed_tokens_tensor,
                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.tp_group,
+                        group=self.backup_tp_group,
                    )
                    min_completed_tokens = completed_tokens_tensor.item()

@@ -222,6 +222,7 @@ class DataParallelController:
                + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
            )
+            moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
            proc = mp.Process(
                target=run_scheduler_process,
                args=(
@@ -229,6 +230,7 @@ class DataParallelController:
                    rank_port_args,
                    gpu_id,
                    tp_rank,
+                    moe_ep_rank,
                    pp_rank,
                    dp_rank,
                    writer,
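The data parallel controller now derives a MoE expert-parallel rank for each scheduler process from its TP rank; with `tp_size // ep_size` TP ranks per EP group, consecutive TP ranks share an EP rank. A quick worked example under assumed sizes (not a real launch configuration):

```python
# Worked example of the new rank derivation for tp_size=8, ep_size=4:
tp_size, ep_size = 8, 4
for tp_rank in range(tp_size):
    moe_ep_rank = tp_rank // (tp_size // ep_size)
    print(tp_rank, moe_ep_rank)
# tp_rank 0,1 -> ep rank 0; 2,3 -> 1; 4,5 -> 2; 6,7 -> 3
```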