sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
```diff
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    …
+    get_target_module_name,
 )

 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        …
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.…
+        self.target_modules: Set[str] = target_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            …
-            return …
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)

         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -139,10 +139,10 @@ class LoRAMemoryPool:

         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            …
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in …
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:

         init_buffer(
             self.A_buffer,
-            self.…
+            self.target_modules,
             self.get_lora_A_shape,
         )

         init_buffer(
             self.B_buffer,
-            self.…
+            self.target_modules,
             self.get_lora_B_shape,
         )

@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                …
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                …
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                …
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[…
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[…
+                    temp_B_buffer[target_module] = weights

             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    …
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )

-                    if temp_A_buffer[…
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue

-                    temp_A_buffer[…
-                        temp_A_buffer[…
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[…
-                        temp_B_buffer[…
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )

             for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
-        self, …
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[…
+            return self.A_buffer[target_module][layer_id]

-        return self.B_buffer[…
+        return self.B_buffer[target_module][layer_id]

     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
```
sglang/srt/lora/utils.py
CHANGED
```diff
@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()


-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(

     result = set()
     for name in target_modules:
-        …
-        result.add(…
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result


@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1


-def …
-    target_name: str, lora_weight_names: Tuple[Set[str]]…
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the …
+    Get the target module name in target_modules that can match full_module_name.

-    If there is a …
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for …
-        if …
-            return …
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find …
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


```
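The renamed helper above resolves a fully qualified weight name to one of the configured target modules by substring matching. A minimal, self-contained sketch of that rule (the module path and target names below are illustrative, not taken from this diff):

```python
from typing import Set


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    # Same rule as the new sglang.srt.lora.utils helper: return a target
    # module whose name occurs as a substring of the full module path.
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


targets = {"qkv_proj", "o_proj", "gate_up_proj"}
name = "layers.0.self_attn.qkv_proj.lora_A.weight"  # illustrative weight name
assert get_target_module_name(name, targets) == "qkv_proj"
```

This is the lookup `LoRAMemoryPool` now uses to key its A/B buffers, in place of the old per-weight-name mapping.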
sglang/srt/managers/cache_controller.py
CHANGED
```diff
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost

 logger = logging.getLogger(__name__)

@@ -238,13 +240,14 @@ class HiCacheController:
         self.io_backend = io_backend

         self.enable_storage = False
+        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

             if storage_backend == "file":
-                self.storage_backend = HiCacheFile()
+                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,23 +260,26 @@ class HiCacheController:
                     get_hash_str_mooncake,
                 )

-                self.storage_backend = MooncakeStore()
+                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
-                from sglang.srt.distributed import get_tensor_model_parallel_rank
                 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                     HiCacheHF3FS,
                 )

-                …
-                …
-                …
-                …
+                if self.mem_pool_host.layout == "page_first":
+                    bytes_per_page = (
+                        mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                    )
+                elif self.mem_pool_host.layout == "layer_first":
+                    bytes_per_page = (
+                        mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                    )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    …
+                    bytes_per_page, dtype
                 )
                 self.get_hash_str = get_hash_str
             else:
@@ -394,6 +400,15 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

+    @property
+    def backup_skip(self):
+        return (
+            self.is_mla
+            and get_tensor_model_parallel_rank() != 0
+            # todo: only support file and mooncake
+            and self.storage_backend_type in ["file", "mooncake"]
+        )
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -555,13 +570,34 @@ class HiCacheController:
             operation.mark_done()
             return operation.completed_tokens, operation.hash_value

+    def zerocopy_page_transfer(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_dsts = dsts[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    completed_tokens += self.page_size
+            else:
+                break
+
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             # todo: zero copy
-            dummy_page_dst = […
-                …
-            …
+            dummy_page_dst = [
+                self.mem_pool_host.get_dummy_flat_data_page()
+                for _ in range(len(page_hashes))
+            ]
             page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
@@ -599,7 +635,10 @@ class HiCacheController:
         if self.is_mooncake_backend():
             self.mooncake_page_transfer(operation)
         elif self.storage_backend_type == "hf3fs":
-            self.…
+            if self.mem_pool_host.layout == "page_first":
+                self.zerocopy_page_transfer(operation, batch_size=128)
+            elif self.mem_pool_host.layout == "layer_first":
+                self.generic_page_transfer(operation, batch_size=128)
         else:
             self.generic_page_transfer(operation)

@@ -716,6 +755,19 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id

+    def zerocopy_page_backup(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_data = dsts[i : i + batch_size]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
     def generic_page_backup(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
@@ -767,14 +819,20 @@ class HiCacheController:
             if operation is None:
                 continue

-            if self.is_mooncake_backend():
-                self.mooncake_page_backup(operation)
-            elif self.storage_backend_type == "hf3fs":
-                self.generic_page_backup(operation, batch_size=128)
+            if not self.backup_skip:
+                if self.is_mooncake_backend():
+                    self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    if self.mem_pool_host.layout == "page_first":
+                        self.zerocopy_page_backup(operation, batch_size=128)
+                    elif self.mem_pool_host.layout == "layer_first":
+                        self.generic_page_backup(operation, batch_size=128)
+                else:
+                    self.generic_page_backup(operation)
+                min_completed_tokens = operation.completed_tokens
             else:
-                self.generic_page_backup(operation)
+                min_completed_tokens = len(operation.token_ids)

-            min_completed_tokens = operation.completed_tokens
             if self.tp_world_size > 1:
                 completed_tokens_tensor = torch.tensor(
                     min_completed_tokens, dtype=torch.int
```
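The new `backup_skip` property encodes the rule that motivates the `is_mla` plumbing above: for MLA models the host KV pages are identical across tensor-parallel ranks, so only rank 0 needs to write them to a shared `file` or `mooncake` store. A standalone sketch of the predicate, with `tp_rank` standing in for `get_tensor_model_parallel_rank()`:

```python
def backup_skip(is_mla: bool, tp_rank: int, storage_backend_type: str) -> bool:
    # Mirrors the property added to HiCacheController: non-zero TP ranks skip
    # the redundant backup of MLA pages; only the "file" and "mooncake"
    # backends opt in for now (per the inline todo).
    return is_mla and tp_rank != 0 and storage_backend_type in ["file", "mooncake"]


assert backup_skip(True, 1, "mooncake")       # replica rank: skip the write
assert not backup_skip(True, 0, "mooncake")   # rank 0 still backs up
assert not backup_skip(False, 1, "file")      # non-MLA: every rank writes
```

When the skip fires, the backup loop reports `len(operation.token_ids)` as its completed-token count, so the cross-rank reduction over `min_completed_tokens` is not dragged down by ranks that intentionally wrote nothing.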
sglang/srt/managers/detokenizer_manager.py
CHANGED
```diff
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    FreezeGCReq,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
+    freeze_gc,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -100,6 +102,7 @@ class DetokenizerManager:
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )

@@ -108,7 +111,8 @@ class DetokenizerManager:
         while True:
             recv_obj = self.recv_from_scheduler.recv_pyobj()
             output = self._request_dispatcher(recv_obj)
-            self.send_to_tokenizer.send_pyobj(output)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)

     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -216,7 +220,7 @@ class DetokenizerManager:
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
-            output_ids=recv_obj.…
+            output_ids=recv_obj.decode_ids,
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
@@ -247,6 +251,10 @@ class DetokenizerManager:
             cached_tokens=recv_obj.cached_tokens,
         )

+    def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
+        freeze_gc("Detokenizer Manager")
+        return None
+

 class LimitedCapacityDict(OrderedDict):
     def __init__(self, capacity: int, *args, **kwargs):
```
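Together with the scheduler change further down, this completes a `FreezeGCReq` fan-out: the scheduler freezes its own GC and forwards the request, and the detokenizer freezes last, returning `None` so nothing is sent back to the tokenizer (hence the new `if output is not None` guard). `sglang.srt.utils.freeze_gc` itself is outside this diff; a plausible sketch, assuming it wraps CPython's `gc.freeze()`:

```python
import gc
import logging

logger = logging.getLogger(__name__)


def freeze_gc(context: str) -> None:
    # Assumed implementation: promote every object that survived startup into
    # the permanent generation so steady-state collections stop scanning them.
    gc.collect()
    gc.freeze()
    logger.info("Froze %d objects in %s", gc.get_freeze_count(), context)
```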
sglang/srt/managers/io_struct.py
CHANGED
```diff
@@ -612,6 +612,8 @@ class EmbeddingReqInput:

         if self.sampling_params is None:
             self.sampling_params = [{}] * self.batch_size
+        elif isinstance(self.sampling_params, dict):
+            self.sampling_params = [self.sampling_params] * self.batch_size
         for i in range(self.batch_size):
             self.sampling_params[i]["max_new_tokens"] = 0

@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
     # For dp balance
     dp_balance_id: int = -1

@@ -1001,6 +1005,11 @@ class ProfileReqOutput:
     message: str


+@dataclass
+class FreezeGCReq:
+    pass
+
+
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
```
sglang/srt/managers/mm_utils.py
CHANGED
```diff
@@ -560,7 +560,7 @@ def embed_mm_inputs(
             ]
             items_size[i + 1] = len(mm_items)
             items_offsets.append(
-                flatten_nested_list([item.offsets for item in …
+                flatten_nested_list([item.offsets for item in mm_items])
             )
         items_size = torch.cumsum(items_size, dim=0).tolist()

```
sglang/srt/managers/schedule_batch.py
CHANGED
```diff
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -83,18 +84,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "chunked_prefill_size",
     "device",
     "disable_chunked_prefix_cache",
+    "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
-    "enable_two_batch_overlap",
-    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
-    "…
-    "deepep_mode",
-    "enable_flashinfer_cutlass_moe",
-    "enable_flashinfer_trtllm_moe",
+    "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
-    "deepep_config",
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
@@ -107,12 +103,11 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
-    "enable_triton_kernel_moe",
-    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
     "enable_custom_logit_processor",
+    "disaggregation_mode",
 ]

 # Put some global args for easy access
```
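The keys removed from `GLOBAL_SERVER_ARGS_KEYS` (two-batch overlap, DeepEP, and several FlashInfer MoE toggles) track state that has moved into the new `sglang.srt.layers.moe` config module, which the scheduler below initializes at startup; this file now imports `is_tbo_enabled` from there. A consumer-side sketch under that assumption (the surrounding function is hypothetical):

```python
from sglang.srt.layers.moe import is_tbo_enabled


def maybe_prepare_two_batch_overlap(batch):
    # Previously read from global_server_args_dict["enable_two_batch_overlap"]
    if is_tbo_enabled():
        ...  # split the batch for the two-batch-overlap path
```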
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe…
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -145,6 +146,7 @@ from sglang.srt.utils import (
     configure_gc_logger,
     configure_logger,
     disable_request_logging,
+    freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
@@ -245,6 +247,9 @@ class Scheduler(
             )
         )

+        # Init model config
+        self.model_config = ModelConfig.from_server_args(server_args)
+
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -292,6 +297,9 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()

+        # Init moe config
+        self.init_moe_config()
+
         # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
         if self.server_args.reasoning_parser and self.tokenizer:
             reasoning_parser = ReasoningParser(
@@ -518,6 +526,7 @@ class Scheduler(
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
+                (FreezeGCReq, self.handle_freeze_gc),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
@@ -538,8 +547,6 @@ class Scheduler(

     def init_tokenizer(self):
         server_args = self.server_args
-
-        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
@@ -761,6 +768,10 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_moe_config(self):
+        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
+            initialize_moe_config(self.server_args)
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
@@ -1133,7 +1144,7 @@ class Scheduler(
                 f"boostrap room id. {req.rid=}"
             )
             logger.error(error_msg)
-            prepare_abort(req, error_msg)
+            prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
             self.stream_output([req], req.return_logprob)
             return

@@ -1823,11 +1834,6 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
-            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=MoeA2ABackend(
-                self.server_args.moe_a2a_backend
-            ).is_deepep(),
-            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
@@ -1922,9 +1928,6 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
-        enable_two_batch_overlap: bool,
-        enable_deepep_moe: bool,
-        deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
     ):
@@ -1972,9 +1975,6 @@ class Scheduler(
                     is_extend_in_batch,
                     *tbo_preparer.prepare_all_gather(
                         local_batch,
-                        deepep_mode,
-                        enable_deepep_moe,
-                        enable_two_batch_overlap,
                     ),
                 ],
                 dtype=torch.int64,
@@ -2472,6 +2472,12 @@ class Scheduler(
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()

+    def handle_freeze_gc(self, recv_req: FreezeGCReq):
+        """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
+        freeze_gc("Scheduler")
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return None
+

 class IdleSleeper:
     """
@@ -2582,7 +2588,10 @@ def run_scheduler_process(
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap_disagg_prefill()
         else:
-            scheduler.event_loop_normal_disagg_prefill()
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()

     elif disaggregation_mode == DisaggregationMode.DECODE:
         if scheduler.enable_overlap:
```
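The `(FreezeGCReq, self.handle_freeze_gc)` entry extends the scheduler's type-dispatched request table; `handle_freeze_gc` freezes the scheduler's GC, forwards the request to the detokenizer, and returns `None`. The dispatcher class itself is outside this diff; a minimal stand-in with the same contract:

```python
from dataclasses import dataclass


@dataclass
class FreezeGCReq:
    pass


class TypeBasedDispatcher:
    # Stand-in for the dispatcher used by Scheduler and DetokenizerManager:
    # route an incoming request object to the handler registered for its type.
    def __init__(self, mapping):
        self._mapping = mapping

    def __call__(self, obj):
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid request: {obj}")


dispatcher = TypeBasedDispatcher([(FreezeGCReq, lambda req: None)])
assert dispatcher(FreezeGCReq()) is None  # None results are simply dropped
```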
sglang/srt/managers/session_controller.py
CHANGED
```diff
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + "…
+            prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret

```
sglang/srt/managers/template_manager.py
CHANGED
```diff
@@ -89,6 +89,7 @@ class TemplateManager:
         if template is None:
             return False

+        # TODO: remove this hard code the reasoning pattern
         force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
         has_reasoning = re.search(force_reasoning_pattern, template) is not None

@@ -128,11 +129,12 @@ class TemplateManager:
             logger.info(
                 f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
             )
-        …
-        …
-        …
-        …
-        …
+        else:
+            # Default to string content format if no template was found
+            self._jinja_template_content_format = "string"
+            logger.info(
+                "No chat template found, defaulting to 'string' content format"
+            )

         # Detect reasoning pattern from chat template
         if tokenizer_manager.tokenizer:
```