sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
|
|
78
78
|
dst_kv_ptrs: list[int]
|
79
79
|
dst_aux_ptrs: list[int]
|
80
80
|
gpu_id: int
|
81
|
+
decode_tp_size: int
|
82
|
+
decode_tp_rank: int
|
83
|
+
dst_kv_item_len: int
|
81
84
|
|
82
85
|
@classmethod
|
83
86
|
def from_zmq(cls, msg: List[bytes]):
|
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
|
|
90
93
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
91
94
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
92
95
|
gpu_id=int(msg[7].decode("ascii")),
|
96
|
+
decode_tp_size=int(msg[8].decode("ascii")),
|
97
|
+
decode_tp_rank=int(msg[9].decode("ascii")),
|
98
|
+
dst_kv_item_len=int(msg[10].decode("ascii")),
|
93
99
|
)
|
94
100
|
|
95
101
|
|
@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
|
|
166
172
|
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
167
173
|
):
|
168
174
|
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
|
169
|
-
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM"
|
175
|
+
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
|
170
176
|
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
|
171
177
|
if not self.kv_descs:
|
172
178
|
raise Exception("NIXL memory registration failed for kv tensors")
|
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
|
|
175
181
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
176
182
|
):
|
177
183
|
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
|
178
|
-
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM"
|
184
|
+
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
|
179
185
|
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
|
180
186
|
if not self.aux_descs:
|
181
187
|
raise Exception("NIXL memory registration failed for aux tensors")
|
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
|
|
222
228
|
logger.debug(
|
223
229
|
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
|
224
230
|
)
|
225
|
-
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM"
|
226
|
-
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM"
|
231
|
+
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
|
232
|
+
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
|
227
233
|
# Transfer data
|
228
234
|
xfer_handle = self.agent.initialize_xfer(
|
229
235
|
"WRITE",
|
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
|
|
239
245
|
raise Exception("KVSender failed to post transfer")
|
240
246
|
return xfer_handle
|
241
247
|
|
248
|
+
def send_kvcache_slice(
|
249
|
+
self,
|
250
|
+
peer_name: str,
|
251
|
+
prefill_kv_indices: npt.NDArray[np.int32],
|
252
|
+
dst_kv_ptrs: list[int],
|
253
|
+
dst_kv_indices: npt.NDArray[np.int32],
|
254
|
+
dst_gpu_id: int,
|
255
|
+
notif: str,
|
256
|
+
prefill_tp_size: int,
|
257
|
+
decode_tp_size: int,
|
258
|
+
decode_tp_rank: int,
|
259
|
+
dst_kv_item_len: int,
|
260
|
+
):
|
261
|
+
# Get configuration from kv_args
|
262
|
+
local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
|
263
|
+
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
|
264
|
+
num_kv_heads = self.kv_args.kv_head_num
|
265
|
+
|
266
|
+
# Calculate head distribution
|
267
|
+
src_heads_per_rank = num_kv_heads
|
268
|
+
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
|
269
|
+
|
270
|
+
src_kv_item_len = self.kv_args.kv_item_lens[0]
|
271
|
+
page_size = self.kv_args.page_size
|
272
|
+
|
273
|
+
bytes_per_head_slice_to_send = (
|
274
|
+
dst_kv_item_len // page_size // dst_heads_per_rank
|
275
|
+
)
|
276
|
+
|
277
|
+
# Determine which heads to send
|
278
|
+
if prefill_tp_size > decode_tp_size:
|
279
|
+
# Multiple prefill ranks to one decode rank
|
280
|
+
src_head_start_offset = 0
|
281
|
+
num_heads_to_send = src_heads_per_rank
|
282
|
+
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
283
|
+
else:
|
284
|
+
# Send KVCache from 1 prefill instance to multiple decode instances
|
285
|
+
src_head_start_offset = (
|
286
|
+
dst_tp_rank_in_group * dst_heads_per_rank
|
287
|
+
) % src_heads_per_rank
|
288
|
+
num_heads_to_send = dst_heads_per_rank
|
289
|
+
dst_head_start_offset = 0
|
290
|
+
|
291
|
+
# Create transfer descriptors
|
292
|
+
src_addrs = []
|
293
|
+
dst_addrs = []
|
294
|
+
|
295
|
+
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
296
|
+
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
297
|
+
|
298
|
+
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
299
|
+
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
300
|
+
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
301
|
+
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
302
|
+
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
303
|
+
|
304
|
+
# Calculate precise byte offset and length for the sub-slice within the token
|
305
|
+
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
306
|
+
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
307
|
+
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
|
308
|
+
|
309
|
+
src_dst_ptr_pairs = [
|
310
|
+
(
|
311
|
+
src_k_ptrs[layer_id],
|
312
|
+
dst_k_ptrs[layer_id],
|
313
|
+
)
|
314
|
+
for layer_id in range(len(src_k_ptrs))
|
315
|
+
] + [
|
316
|
+
(
|
317
|
+
src_v_ptrs[layer_id],
|
318
|
+
dst_v_ptrs[layer_id],
|
319
|
+
)
|
320
|
+
for layer_id in range(len(src_v_ptrs))
|
321
|
+
]
|
322
|
+
|
323
|
+
src_addrs = []
|
324
|
+
dst_addrs = []
|
325
|
+
|
326
|
+
# Calculate strides for a single token slot
|
327
|
+
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
328
|
+
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
329
|
+
|
330
|
+
for src_ptr, dst_ptr in src_dst_ptr_pairs:
|
331
|
+
for i in range(len(prefill_kv_indices)):
|
332
|
+
prefill_page_idx = int(prefill_kv_indices[i])
|
333
|
+
decode_page_idx = int(dst_kv_indices[i])
|
334
|
+
|
335
|
+
# Get the starting addresses for the current src and dst pages
|
336
|
+
src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
|
337
|
+
dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
|
338
|
+
|
339
|
+
# Iterate through each valid token slot within the current page
|
340
|
+
for token_slot_in_page in range(page_size):
|
341
|
+
# Calculate the start address of the current token slot
|
342
|
+
src_token_slot_start_addr = (
|
343
|
+
src_page_start_addr
|
344
|
+
+ token_slot_in_page * bytes_per_token_on_prefill
|
345
|
+
)
|
346
|
+
dst_token_slot_start_addr = (
|
347
|
+
dst_page_start_addr
|
348
|
+
+ token_slot_in_page * bytes_per_token_on_decode
|
349
|
+
)
|
350
|
+
|
351
|
+
# Calculate final src and dst addresses by applying head-slice offsets
|
352
|
+
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
|
353
|
+
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
|
354
|
+
|
355
|
+
src_addrs.append(
|
356
|
+
(
|
357
|
+
src_slice_addr,
|
358
|
+
heads_bytes_per_token_to_send,
|
359
|
+
self.kv_args.gpu_id,
|
360
|
+
)
|
361
|
+
)
|
362
|
+
dst_addrs.append(
|
363
|
+
(dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
|
364
|
+
)
|
365
|
+
|
366
|
+
# Use NIXL agent for transfer
|
367
|
+
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
|
368
|
+
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
|
369
|
+
|
370
|
+
xfer_handle = self.agent.initialize_xfer(
|
371
|
+
"WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
|
372
|
+
)
|
373
|
+
if not xfer_handle:
|
374
|
+
raise Exception("Failed to create sliced KV transfer")
|
375
|
+
|
376
|
+
state = self.agent.transfer(xfer_handle)
|
377
|
+
if state == "ERR":
|
378
|
+
raise Exception("Failed to post sliced KV transfer")
|
379
|
+
|
380
|
+
return xfer_handle
|
381
|
+
|
242
382
|
def send_aux(
|
243
383
|
self,
|
244
384
|
peer_name: str,
|
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
|
|
255
395
|
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
256
396
|
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
257
397
|
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
258
|
-
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM"
|
259
|
-
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM"
|
398
|
+
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
399
|
+
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
260
400
|
# Transfer data
|
261
401
|
xfer_handle = self.agent.initialize_xfer(
|
262
402
|
"WRITE",
|
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
|
|
296
436
|
assert req.agent_name in self.decode_kv_args_table
|
297
437
|
|
298
438
|
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
self.
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
439
|
+
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
|
440
|
+
|
441
|
+
if decode_tp_size == self.tp_size:
|
442
|
+
kv_xfer_handle = self.send_kvcache(
|
443
|
+
req.agent_name,
|
444
|
+
kv_indices,
|
445
|
+
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
446
|
+
chunked_dst_kv_indice,
|
447
|
+
self.decode_kv_args_table[req.agent_name].gpu_id,
|
448
|
+
notif,
|
449
|
+
)
|
450
|
+
else:
|
451
|
+
kv_xfer_handle = self.send_kvcache_slice(
|
452
|
+
req.agent_name,
|
453
|
+
kv_indices,
|
454
|
+
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
455
|
+
chunked_dst_kv_indice,
|
456
|
+
self.decode_kv_args_table[req.agent_name].gpu_id,
|
457
|
+
notif,
|
458
|
+
prefill_tp_size=self.tp_size,
|
459
|
+
decode_tp_size=decode_tp_size,
|
460
|
+
decode_tp_rank=self.decode_kv_args_table[
|
461
|
+
req.agent_name
|
462
|
+
].decode_tp_rank,
|
463
|
+
dst_kv_item_len=self.decode_kv_args_table[
|
464
|
+
req.agent_name
|
465
|
+
].dst_kv_item_len,
|
466
|
+
)
|
467
|
+
|
307
468
|
handles.append(kv_xfer_handle)
|
308
469
|
# Only the last chunk we need to send the aux data.
|
309
470
|
if is_last:
|
@@ -454,11 +615,11 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
454
615
|
mgr: NixlKVManager,
|
455
616
|
bootstrap_addr: str,
|
456
617
|
bootstrap_room: Optional[int] = None,
|
457
|
-
|
618
|
+
prefill_dp_rank: Optional[int] = None,
|
458
619
|
):
|
459
620
|
self.started_transfer = False
|
460
621
|
self.conclude_state = None
|
461
|
-
super().__init__(mgr, bootstrap_addr, bootstrap_room,
|
622
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
462
623
|
|
463
624
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
464
625
|
for bootstrap_info in self.bootstrap_infos:
|
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
521
682
|
packed_kv_data_ptrs,
|
522
683
|
packed_aux_data_ptrs,
|
523
684
|
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
685
|
+
str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
|
686
|
+
str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
|
687
|
+
str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
|
524
688
|
]
|
525
689
|
)
|
526
690
|
|
@@ -23,7 +23,7 @@ import logging
|
|
23
23
|
import threading
|
24
24
|
from collections import deque
|
25
25
|
from http import HTTPStatus
|
26
|
-
from typing import TYPE_CHECKING, List, Optional
|
26
|
+
from typing import TYPE_CHECKING, List, Optional, Type
|
27
27
|
|
28
28
|
import torch
|
29
29
|
|
@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
|
|
140
140
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
141
141
|
kv_args.gpu_id = self.scheduler.gpu_id
|
142
142
|
|
143
|
-
kv_manager_class = get_kv_class(
|
144
|
-
|
143
|
+
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
144
|
+
self.transfer_backend, KVClassType.MANAGER
|
145
|
+
)
|
146
|
+
kv_manager: BaseKVManager = kv_manager_class(
|
145
147
|
kv_args,
|
146
148
|
DisaggregationMode.PREFILL,
|
147
149
|
self.scheduler.server_args,
|
@@ -1,21 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import dataclasses
|
4
3
|
import os
|
5
4
|
import random
|
6
|
-
import threading
|
7
|
-
import warnings
|
8
5
|
from collections import deque
|
9
6
|
from contextlib import nullcontext
|
10
7
|
from enum import Enum
|
11
|
-
from typing import TYPE_CHECKING, List, Optional
|
8
|
+
from typing import TYPE_CHECKING, List, Optional, Type, Union
|
12
9
|
|
13
10
|
import numpy as np
|
14
|
-
import requests
|
15
11
|
import torch
|
16
12
|
import torch.distributed as dist
|
17
13
|
|
18
|
-
from sglang.srt.utils import
|
14
|
+
from sglang.srt.utils import is_npu
|
19
15
|
|
20
16
|
if TYPE_CHECKING:
|
21
17
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -217,7 +213,9 @@ class KVClassType(Enum):
|
|
217
213
|
BOOTSTRAP_SERVER = "bootstrap_server"
|
218
214
|
|
219
215
|
|
220
|
-
def get_kv_class(
|
216
|
+
def get_kv_class(
|
217
|
+
transfer_backend: TransferBackend, class_type: KVClassType
|
218
|
+
) -> Optional[Type]:
|
221
219
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
222
220
|
|
223
221
|
if transfer_backend == TransferBackend.MOONCAKE:
|
@@ -305,49 +303,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
|
|
305
303
|
return (num_kv_indices + page_size - 1) // page_size
|
306
304
|
|
307
305
|
|
308
|
-
#########################
|
309
|
-
# PDLB Registry
|
310
|
-
#########################
|
311
|
-
|
312
|
-
|
313
|
-
@dataclasses.dataclass
|
314
|
-
class PDRegistryRequest:
|
315
|
-
"""A request to register a machine itself to the LB."""
|
316
|
-
|
317
|
-
mode: str
|
318
|
-
registry_url: str
|
319
|
-
bootstrap_port: Optional[int] = None
|
320
|
-
|
321
|
-
def __post_init__(self):
|
322
|
-
if self.mode == "prefill" and self.bootstrap_port is None:
|
323
|
-
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
324
|
-
elif self.mode == "decode" and self.bootstrap_port is not None:
|
325
|
-
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
326
|
-
elif self.mode not in ["prefill", "decode"]:
|
327
|
-
raise ValueError(
|
328
|
-
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
329
|
-
)
|
330
|
-
|
331
|
-
|
332
|
-
def register_disaggregation_server(
|
333
|
-
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
334
|
-
):
|
335
|
-
boostrap_port = bootstrap_port if mode == "prefill" else None
|
336
|
-
registry_request = PDRegistryRequest(
|
337
|
-
mode=mode,
|
338
|
-
registry_url=f"http://{get_ip()}:{server_port}",
|
339
|
-
bootstrap_port=boostrap_port,
|
340
|
-
)
|
341
|
-
res = requests.post(
|
342
|
-
f"{pdlb_url}/register",
|
343
|
-
json=dataclasses.asdict(registry_request),
|
344
|
-
)
|
345
|
-
if res.status_code != 200:
|
346
|
-
warnings.warn(
|
347
|
-
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
348
|
-
)
|
349
|
-
|
350
|
-
|
351
306
|
#########################
|
352
307
|
# Misc
|
353
308
|
#########################
|
@@ -64,6 +64,9 @@ class GraphCaptureContext:
|
|
64
64
|
|
65
65
|
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
66
66
|
|
67
|
+
# use int value instead of ReduceOp.SUM to support torch compile
|
68
|
+
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
|
69
|
+
|
67
70
|
|
68
71
|
def _split_tensor_dict(
|
69
72
|
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
@@ -489,9 +492,7 @@ class GroupCoordinator:
|
|
489
492
|
|
490
493
|
if input_.is_cpu:
|
491
494
|
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
492
|
-
torch.ops.sgl_kernel.shm_allreduce(
|
493
|
-
input_, torch.distributed.ReduceOp.SUM
|
494
|
-
)
|
495
|
+
torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
|
495
496
|
else:
|
496
497
|
torch.distributed.all_reduce(input_, group=self.device_group)
|
497
498
|
return input_
|
@@ -1586,6 +1587,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
|
|
1586
1587
|
_TP = old_tp_group
|
1587
1588
|
|
1588
1589
|
|
1590
|
+
def get_world_size():
|
1591
|
+
"""Return world size for the world group."""
|
1592
|
+
return get_world_group().world_size
|
1593
|
+
|
1594
|
+
|
1595
|
+
def get_world_rank():
|
1596
|
+
"""Return my rank for the world group."""
|
1597
|
+
return get_world_group().rank_in_group
|
1598
|
+
|
1599
|
+
|
1589
1600
|
def get_tensor_model_parallel_world_size():
|
1590
1601
|
"""Return world size for the tensor model parallel group."""
|
1591
1602
|
return get_tp_group().world_size
|
@@ -1596,6 +1607,16 @@ def get_tensor_model_parallel_rank():
|
|
1596
1607
|
return get_tp_group().rank_in_group
|
1597
1608
|
|
1598
1609
|
|
1610
|
+
def get_pipeline_model_parallel_world_size():
|
1611
|
+
"""Return world size for the pipeline model parallel group."""
|
1612
|
+
return get_pp_group().world_size
|
1613
|
+
|
1614
|
+
|
1615
|
+
def get_pipeline_model_parallel_rank():
|
1616
|
+
"""Return my rank for the pipeline model parallel group."""
|
1617
|
+
return get_pp_group().rank_in_group
|
1618
|
+
|
1619
|
+
|
1599
1620
|
def get_moe_expert_parallel_world_size():
|
1600
1621
|
"""Return world size for the moe expert parallel group."""
|
1601
1622
|
return get_moe_ep_group().world_size
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -33,6 +33,8 @@ import zmq
|
|
33
33
|
import zmq.asyncio
|
34
34
|
from PIL.Image import Image
|
35
35
|
|
36
|
+
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
|
37
|
+
|
36
38
|
# Fix a bug of Python threading
|
37
39
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
38
40
|
|
@@ -138,6 +140,12 @@ class Engine(EngineBase):
|
|
138
140
|
context, zmq.DEALER, self.port_args.rpc_ipc_name, True
|
139
141
|
)
|
140
142
|
|
143
|
+
if server_args.enable_trace:
|
144
|
+
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
|
145
|
+
if server_args.disaggregation_mode == "null":
|
146
|
+
thread_label = "Tokenizer"
|
147
|
+
trace_set_thread_info(thread_label)
|
148
|
+
|
141
149
|
def generate(
|
142
150
|
self,
|
143
151
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
@@ -364,9 +372,9 @@ class Engine(EngineBase):
|
|
364
372
|
loop = asyncio.get_event_loop()
|
365
373
|
return loop.run_until_complete(self.tokenizer_manager.flush_cache())
|
366
374
|
|
367
|
-
def start_profile(self):
|
375
|
+
def start_profile(self, **kwargs):
|
368
376
|
loop = asyncio.get_event_loop()
|
369
|
-
loop.run_until_complete(self.tokenizer_manager.start_profile())
|
377
|
+
loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
|
370
378
|
|
371
379
|
def stop_profile(self):
|
372
380
|
loop = asyncio.get_event_loop()
|
@@ -655,7 +663,8 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
655
663
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
656
664
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
657
665
|
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
658
|
-
os.environ
|
666
|
+
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
|
667
|
+
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
659
668
|
|
660
669
|
# Can also be passed as argument
|
661
670
|
os.environ["SGLANG_RUN_ID"] = (
|
@@ -673,7 +682,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
673
682
|
if server_args.attention_backend == "flashinfer":
|
674
683
|
assert_pkg_version(
|
675
684
|
"flashinfer_python",
|
676
|
-
"0.3.
|
685
|
+
"0.3.1",
|
677
686
|
"Please uninstall the old version and "
|
678
687
|
"reinstall the latest version by following the instructions "
|
679
688
|
"at https://docs.flashinfer.ai/installation.html.",
|
@@ -681,7 +690,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
681
690
|
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
682
691
|
assert_pkg_version(
|
683
692
|
"sgl-kernel",
|
684
|
-
"0.3.
|
693
|
+
"0.3.9.post2",
|
685
694
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
686
695
|
)
|
687
696
|
|
@@ -703,6 +712,24 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
703
712
|
mp.set_start_method("spawn", force=True)
|
704
713
|
|
705
714
|
|
715
|
+
def _init_tokenizer_manager(
|
716
|
+
server_args: ServerArgs, port_args: PortArgs
|
717
|
+
) -> TokenizerManager:
|
718
|
+
# Launch tokenizer process
|
719
|
+
tokenizer_manager = TokenizerManager(server_args, port_args)
|
720
|
+
|
721
|
+
# Initialize templates
|
722
|
+
template_manager = TemplateManager()
|
723
|
+
template_manager.initialize_templates(
|
724
|
+
tokenizer_manager=tokenizer_manager,
|
725
|
+
model_path=server_args.model_path,
|
726
|
+
chat_template=server_args.chat_template,
|
727
|
+
completion_template=server_args.completion_template,
|
728
|
+
)
|
729
|
+
|
730
|
+
return tokenizer_manager, template_manager
|
731
|
+
|
732
|
+
|
706
733
|
def _launch_subprocesses(
|
707
734
|
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
708
735
|
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
|
@@ -815,23 +842,15 @@ def _launch_subprocesses(
|
|
815
842
|
),
|
816
843
|
)
|
817
844
|
detoken_proc.start()
|
845
|
+
|
846
|
+
# Init tokenizer manager first, as the bootstrap server is initialized here
|
818
847
|
if server_args.tokenizer_worker_num > 1:
|
819
848
|
# Launch multi-tokenizer router
|
820
849
|
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
|
821
|
-
|
822
|
-
# Initialize templates
|
823
850
|
template_manager = None
|
824
851
|
else:
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
# Initialize templates
|
829
|
-
template_manager = TemplateManager()
|
830
|
-
template_manager.initialize_templates(
|
831
|
-
tokenizer_manager=tokenizer_manager,
|
832
|
-
model_path=server_args.model_path,
|
833
|
-
chat_template=server_args.chat_template,
|
834
|
-
completion_template=server_args.completion_template,
|
852
|
+
tokenizer_manager, template_manager = _init_tokenizer_manager(
|
853
|
+
server_args, port_args
|
835
854
|
)
|
836
855
|
|
837
856
|
# Wait for the model to finish loading
|
@@ -855,5 +874,7 @@ def _launch_subprocesses(
|
|
855
874
|
|
856
875
|
# Assume all schedulers have the same scheduler_info
|
857
876
|
scheduler_info = scheduler_infos[0]
|
877
|
+
|
858
878
|
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
879
|
+
|
859
880
|
return tokenizer_manager, template_manager, scheduler_info
|