sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -640
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +89 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +2 -0
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +10 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens

-
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )

-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
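The kernel change above does two things: the token and expert offsets are promoted to `tl.int64` before they enter the pointer arithmetic, and experts mapped to `-1` (not owned by the current expert-parallel rank) short-circuit through `write_zeros_to_output`. A minimal sketch of the overflow that the 64-bit promotion guards against, with invented numbers purely for illustration:

```python
# Illustrative only: a flat offset of the kind the kernel computes
# (row index * row stride) can exceed the int32 range for large tensors,
# which is why the index math above is promoted to 64-bit.
INT32_MAX = 2**31 - 1

row, stride = 300_000, 16_384      # hypothetical sorted-token row and hidden stride
flat_offset = row * stride
print(flat_offset)                 # 4915200000
print(flat_offset > INT32_MAX)     # True -> needs int64 indexing
```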
@@ -497,7 +516,6 @@ def fused_moe_kernel(

                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                # fix out of shared memory issue
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
                 else:
@@ -568,7 +586,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +596,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts +
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )

     # Threshold based on benchmark results
@@ -594,12 +608,11 @@ def moe_align_block_size(

         sgl_moe_align_block_size(
             topk_ids,
-            num_experts,
+            num_experts + 1,
             block_size,
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
-            token_cnts_buffer,
             cumsum_buffer,
             fuse_sorted_ids_padding,
         )
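Under expert parallelism the filtered experts are folded into one extra `-1` bucket, so the worst-case padding bound in `moe_align_block_size` grows from `num_experts` to `num_experts + 1` buckets, each wasting at most `block_size - 1` slots. A small worked check of that bound (the helper name is illustrative only):

```python
# Hedged sketch of the worst-case size of sorted_ids after per-expert padding:
# every one of the (num_experts + 1) buckets may be padded by up to block_size - 1 slots.
def max_num_tokens_padded(num_topk_ids: int, num_experts: int, block_size: int) -> int:
    return num_topk_ids + (num_experts + 1) * (block_size - 1)

# e.g. 6 routed token/expert pairs, 2 real experts plus the -1 bucket, block_size 4
assert max_num_tokens_padded(6, num_experts=2, block_size=4) == 6 + 3 * 3  # = 15
```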
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -7,11 +7,16 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -62,8 +67,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -77,21 +83,19 @@ class FusedMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
-        skip_quant: Optional[bool] = False,
     ):
         super().__init__()

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
-        self.
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.expert_map_cpu = None
+        self.expert_map_gpu = None

         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -99,28 +103,27 @@ class FusedMoE(torch.nn.Module):
             enable_ep_moe = False

         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.moe_tp_size = get_moe_tensor_parallel_world_size()
+        self.moe_tp_rank = get_moe_tensor_parallel_rank()
+        assert num_experts % self.moe_ep_size == 0
+        self.num_local_experts = num_experts // self.moe_ep_size
         if enable_ep_moe:
-
-            self.ep_rank = self.tp_rank
-            self.tp_size = 1
-            self.tp_rank = 0
+            # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.
+            self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
-
-
-
-                self.ep_rank
-                * self.num_local_experts : (self.ep_rank + 1)
+            self.expert_map_cpu[
+                self.moe_ep_rank
+                * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
-
-            self.ep_rank = 0
-            self.num_local_experts = num_experts
+            self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
         self.routed_scaling_factor = routed_scaling_factor
-        assert intermediate_size % self.
-        self.intermediate_size_per_partition = intermediate_size // self.
+        assert intermediate_size % self.moe_tp_size == 0
+        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -132,9 +135,6 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )

-        if skip_quant:
-            return
-
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -363,9 +363,9 @@ class FusedMoE(torch.nn.Module):
         expert_data.copy_(loaded_weight)

     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.
+        if self.expert_map_cpu is None:
             return expert_id
-        return self.
+        return self.expert_map_cpu[expert_id].item()

     def weight_loader(
         self,
@@ -375,10 +375,48 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
@@ -396,8 +434,7 @@ class FusedMoE(torch.nn.Module):
         expert_id: int,
     ) -> None:

-
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+        tp_rank = self.moe_tp_rank

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -571,9 +608,14 @@ class FusedMoE(torch.nn.Module):
         )
         return

-    def forward(self, hidden_states: torch.Tensor, topk_output:
+    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None

+        if self.expert_map_gpu is not None:
+            topk_output = topk_output._replace(
+                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
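In `forward`, routing ids are remapped through `expert_map_gpu` before the quant method runs, which is where the `-1` ids consumed by the kernel's zero-write branch come from. A toy illustration of the remap with invented tensors:

```python
import torch

# Rank 1 of 2 with 8 experts: global ids 4..7 are local 0..3, everything else is -1.
expert_map_gpu = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])

topk_ids = torch.tensor([[2, 5], [7, 0]])   # router's global choices for two tokens
local_topk_ids = expert_map_gpu[topk_ids]   # advanced indexing applies the map elementwise
print(local_topk_ids)                       # tensor([[-1,  1], [ 3, -1]])
```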
@@ -584,17 +626,17 @@ class FusedMoE(torch.nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             **(
                 dict(
-                    tp_rank=self.
-                    tp_size=self.
-                    ep_rank=self.
-                    ep_size=self.
+                    tp_rank=self.moe_tp_rank,
+                    tp_size=self.moe_tp_size,
+                    ep_rank=self.moe_ep_rank,
+                    ep_size=self.moe_ep_size,
                 )
                 if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
                 else {}
             ),
         )

-        if self.reduce_results and (self.
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states
@@ -627,3 +669,20 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
+
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
sglang/srt/layers/quantization/fp8.py

@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, EPMoE):
-            return Fp8EPMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

-        if isinstance(layer, EPMoE):
-            layer.w13_weight_scale = (
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            )
-            layer.w2_weight_scale = (
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            )
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

sglang/srt/layers/quantization/unquant.py

@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:

-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-        if isinstance(layer, EPMoE):
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         return self.forward(
             x=x,
             layer=layer,
sglang/srt/managers/cache_controller.py

@@ -26,6 +26,11 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+    MooncakeStore,
+    get_hash_str_mooncake,
+)
+from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

 logger = logging.getLogger(__name__)

@@ -124,7 +129,7 @@ class TransferBuffer:
     """

     def __init__(
-        self, stop_event, buffer_count: int = 3, max_buffer_size: int =
+        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
     ) -> None:
         self.stop_event = stop_event
         self.buffers = Queue(maxsize=buffer_count)
@@ -250,17 +255,39 @@ class HiCacheController:
         self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
         if self.tp_world_size > 1:
             group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-            self.
+            self.prefetch_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )
+            self.backup_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )

         if storage_backend == "file":
             self.storage_backend = HiCacheFile()
-            self.
-
-            self.
+            self.get_hash_str = get_hash_str
+        elif storage_backend == "mooncake":
+            self.storage_backend = MooncakeStore()
+            self.get_hash_str = get_hash_str_mooncake
+            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+        elif storage_backend == "hf3fs":
+            from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+            rank = get_tensor_model_parallel_rank()
+            bytes_per_page = (
+                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+            )
+            dtype = mem_pool_host.dtype
+            self.storage_backend = HiCacheHF3FS.from_env_config(
+                rank, bytes_per_page, dtype
+            )
+            self.get_hash_str = get_hash_str
         else:
             raise NotImplementedError(
                 f"Unsupported storage backend: {storage_backend}"
             )
+        self.enable_storage = True
+        # todo: threshold policy for prefetching
+        self.prefetch_threshold = max(prefetch_threshold, self.page_size)

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -515,6 +542,37 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

+    def generic_page_transfer(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[completed_tokens],
+                        page_data[i],
+                    )
+                    completed_tokens += self.page_size
+            else:
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
+                break
+
+    def mooncake_page_transfer(self, operation):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            operation.hash_value, operation.host_indices
+        )
+        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
+        operation.increment(len(operation.hash_value) * self.page_size)
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
@@ -522,24 +580,10 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-
-
-
-
-                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
-                        )
-                        break
-                    if operation.increment(self.page_size):
-                        self.mem_pool_host.set_from_flat_data_page(
-                            operation.host_indices[operation.completed_tokens],
-                            page_data,
-                        )
-                    else:
-                        # operation terminated by controller, release pre-allocated memory
-                        self.mem_pool_host.free(
-                            operation.host_indices[operation.completed_tokens :]
-                        )
-                        break
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_transfer(operation)
+                else:
+                    self.generic_page_transfer(operation)
             except Empty:
                 continue

@@ -563,18 +607,27 @@ class HiCacheController:
         remaining_tokens = len(tokens_to_fetch)
         hash_value = []
         while remaining_tokens >= self.page_size:
-            last_hash = get_hash_str(
+            last_hash = self.get_hash_str(
                 tokens_to_fetch[
                     storage_hit_count : storage_hit_count + self.page_size
                 ],
                 last_hash,
             )
-
-
-
-
-
-
+
+            # todo, more unified interface
+            if not isinstance(self.storage_backend, MooncakeStore):
+                if not self.storage_backend.exists(last_hash):
+                    break
+            hash_value.append(last_hash)
+            storage_hit_count += self.page_size
+            remaining_tokens -= self.page_size
+
+        if isinstance(self.storage_backend, MooncakeStore):
+            # deferring to batch exists for mooncake store
+            exist_result = self.storage_backend.exists(hash_value)
+            storage_hit_count = (
+                sum(1 for v in exist_result.values() if v != 0) * self.page_size
+            )

         if self.tp_world_size > 1:
             storage_hit_count_tensor = torch.tensor(
@@ -583,7 +636,7 @@ class HiCacheController:
             torch.distributed.all_reduce(
                 storage_hit_count_tensor,
                 op=torch.distributed.ReduceOp.MIN,
-                group=self.
+                group=self.prefetch_tp_group,
             )
             storage_hit_count = storage_hit_count_tensor.item()

@@ -622,6 +675,47 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id

+    def generic_page_backup(self, operation, batch_size=8):
+        for i in range(0, len(operation.hash_value), batch_size):
+            page_hashes = operation.hash_value[i : i + batch_size]
+            page_data = [
+                self.mem_pool_host.get_flat_data_pages(
+                    operation.host_indices[j * self.page_size]
+                )
+                for j in range(i, i + len(page_hashes))
+            ]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
+    def mooncake_page_backup(self, operation):
+        if len(operation.hash_value):
+            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
+            indices = operation.host_indices.tolist()
+            non_exist_keys = []
+            non_exist_indices = []
+            for i in range(len(operation.hash_value)):
+                if not exist_hashvalues[operation.hash_value[i]]:
+                    non_exist_keys.append(operation.hash_value[i])
+                    non_exist_indices.extend(
+                        indices[i * self.page_size : (i + 1) * self.page_size]
+                    )
+            if len(non_exist_keys) > 0:
+                key_strs, buffer_ptrs, buffer_sizes = (
+                    self.mem_pool_host.get_buffer_meta(
+                        non_exist_keys, non_exist_indices
+                    )
+                )
+                # TODO: check the return value of batch set to see how many tokens are set successfully
+                self.storage_backend.batch_set(
+                    key_strs,
+                    target_location=buffer_ptrs,
+                    target_sizes=buffer_sizes,
+                )
+            operation.completed_tokens += len(operation.hash_value) * self.page_size
+
     def backup_thread_func(self):
         """
         Manage backup operations from host memory to storage backend.
@@ -635,21 +729,25 @@ class HiCacheController:
                 last_hash = operation.last_hash
                 tokens_to_backup = operation.token_ids

-
-
-
-
-
+                backup_hit_count = 0
+                remaining_tokens = len(tokens_to_backup)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_backup[
+                            backup_hit_count : backup_hit_count + self.page_size
+                        ],
                         last_hash,
-                        self.mem_pool_host.get_flat_data_page(
-                            operation.host_indices[i]
-                        ),
                     )
-
-
-
-
-
+                    backup_hit_count += self.page_size
+                    hash_value.append(last_hash)
+                    remaining_tokens -= self.page_size
+                operation.hash_value = hash_value
+
+                if isinstance(self.storage_backend, MooncakeStore):
+                    self.mooncake_page_backup(operation)
+                else:
+                    self.generic_page_backup(operation)

                 min_completed_tokens = operation.completed_tokens
                 if self.tp_world_size > 1:
@@ -659,7 +757,7 @@ class HiCacheController:
                     torch.distributed.all_reduce(
                         completed_tokens_tensor,
                         op=torch.distributed.ReduceOp.MIN,
-                        group=self.
+                        group=self.backup_tp_group,
                     )
                     min_completed_tokens = completed_tokens_tensor.item()
sglang/srt/managers/data_parallel_controller.py

@@ -222,6 +222,7 @@ class DataParallelController:
                     + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                     + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
+                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                 proc = mp.Process(
                     target=run_scheduler_process,
                     args=(
@@ -229,6 +230,7 @@ class DataParallelController:
                         rank_port_args,
                         gpu_id,
                         tp_rank,
+                        moe_ep_rank,
                         pp_rank,
                         dp_rank,
                         writer,
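The scheduler's expert-parallel rank is derived from its TP rank alone: consecutive blocks of `tp_size // ep_size` TP ranks share one EP rank. A quick worked check of the formula:

```python
# Worked example of the new derivation: tp_size=8, ep_size=2 puts
# tp ranks 0-3 on EP rank 0 and tp ranks 4-7 on EP rank 1.
tp_size, ep_size = 8, 2
for tp_rank in range(tp_size):
    moe_ep_rank = tp_rank // (tp_size // ep_size)
    print(tp_rank, moe_ep_rank)
    assert moe_ep_rank == (0 if tp_rank < tp_size // ep_size else 1)
```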