sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int

     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )

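Note: the three new fields ride on the existing ZMQ bootstrap message as extra ASCII-encoded frames. A minimal round-trip sketch (values and pointer list are made up; frames 0-7 carry agent metadata elided here):

    import struct

    decode_tp_size, decode_tp_rank, dst_kv_item_len = 2, 1, 65536
    kv_ptrs = [0x7F0000000000, 0x7F0000100000]

    msg = [b""] * 8                            # frames 0-7: agent metadata (elided)
    msg[5] = struct.pack(f"{len(kv_ptrs)}Q", *kv_ptrs)
    msg += [
        str(decode_tp_size).encode("ascii"),   # msg[8]
        str(decode_tp_rank).encode("ascii"),   # msg[9]
        str(dst_kv_item_len).encode("ascii"),  # msg[10]
    ]

    # prefill side, matching the new from_zmq lines above:
    assert list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])) == kv_ptrs
    assert int(msg[8].decode("ascii")) == decode_tp_size
    assert int(msg[10].decode("ascii")) == dst_kv_item_len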
@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM"
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM"
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle

+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
+        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_k_ptrs))
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_v_ptrs))
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
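The head-slice arithmetic above moves only the KV-head byte range that a given decode rank owns. A worked example under assumed sizes (prefill TP=2, decode TP=4, 8 KV heads per prefill rank, head_dim=128, fp16, page_size=16); this is a standalone sketch, not sglang code:

    prefill_tp_size, decode_tp_size = 2, 4
    num_kv_heads = 8                       # KV heads held by one prefill rank
    page_size = 16
    bytes_per_head = 128 * 2               # head_dim=128, fp16 (assumed)

    src_heads_per_rank = num_kv_heads
    dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size  # 4
    dst_kv_item_len = dst_heads_per_rank * bytes_per_head * page_size
    # equals bytes_per_head, as expected:
    bytes_per_head_slice = dst_kv_item_len // page_size // dst_heads_per_rank

    # prefill_tp < decode_tp: each prefill rank feeds two decode ranks,
    # each decode rank taking a different half of its heads.
    for decode_tp_rank in range(decode_tp_size):
        start = (
            (decode_tp_rank % decode_tp_size) * dst_heads_per_rank
        ) % src_heads_per_rank
        print(decode_tp_rank, "-> src heads", start, "to", start + dst_heads_per_rank - 1)
    # 0 -> src heads 0 to 3, 1 -> 4 to 7, 2 -> 0 to 3, 3 -> 4 to 7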
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
         assert req.agent_name in self.decode_kv_args_table

         notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-        kv_xfer_handle = self.send_kvcache(
-            req.agent_name,
-            kv_indices,
-            self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
-            chunked_dst_kv_indice,
-            self.decode_kv_args_table[req.agent_name].gpu_id,
-            notif,
-        )
+        decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+        if decode_tp_size == self.tp_size:
+            kv_xfer_handle = self.send_kvcache(
+                req.agent_name,
+                kv_indices,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                chunked_dst_kv_indice,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
+                notif,
+            )
+        else:
+            kv_xfer_handle = self.send_kvcache_slice(
+                req.agent_name,
+                kv_indices,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                chunked_dst_kv_indice,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
+                notif,
+                prefill_tp_size=self.tp_size,
+                decode_tp_size=decode_tp_size,
+                decode_tp_rank=self.decode_kv_args_table[
+                    req.agent_name
+                ].decode_tp_rank,
+                dst_kv_item_len=self.decode_kv_args_table[
+                    req.agent_name
+                ].dst_kv_item_len,
+            )
+
         handles.append(kv_xfer_handle)
         # Only the last chunk we need to send the aux data.
         if is_last:
@@ -454,11 +615,11 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
                 str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
             ]
         )

sglang/srt/disaggregation/prefill.py
CHANGED
@@ -23,7 +23,7 @@ import logging
 import threading
 from collections import deque
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type

 import torch

@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id

-        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.PREFILL,
             self.scheduler.server_args,
sglang/srt/disaggregation/utils.py
CHANGED
@@ -1,21 +1,17 @@
 from __future__ import annotations

-import dataclasses
 import os
 import random
-import threading
-import warnings
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type, Union

 import numpy as np
-import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip, is_npu
+from sglang.srt.utils import is_npu

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -217,7 +213,9 @@ class KVClassType(Enum):
     BOOTSTRAP_SERVER = "bootstrap_server"


-def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+def get_kv_class(
+    transfer_backend: TransferBackend, class_type: KVClassType
+) -> Optional[Type]:
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

     if transfer_backend == TransferBackend.MOONCAKE:
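A usage sketch of the now-annotated factory, mirroring the PrefillBootstrapQueue hunk above (the NIXL backend and SENDER member names are assumed from the surrounding module):

    from sglang.srt.disaggregation.utils import (
        KVClassType,
        TransferBackend,
        get_kv_class,
    )

    # The factory advertises Optional[Type]; callers pin the concrete base
    # class in their own annotations, as PrefillBootstrapQueue does.
    manager_cls = get_kv_class(TransferBackend.NIXL, KVClassType.MANAGER)
    sender_cls = get_kv_class(TransferBackend.NIXL, KVClassType.SENDER)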
@@ -305,49 +303,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size


-#########################
-# PDLB Registry
-#########################
-
-
-@dataclasses.dataclass
-class PDRegistryRequest:
-    """A request to register a machine itself to the LB."""
-
-    mode: str
-    registry_url: str
-    bootstrap_port: Optional[int] = None
-
-    def __post_init__(self):
-        if self.mode == "prefill" and self.bootstrap_port is None:
-            raise ValueError("Bootstrap port must be set in PREFILL mode.")
-        elif self.mode == "decode" and self.bootstrap_port is not None:
-            raise ValueError("Bootstrap port must not be set in DECODE mode.")
-        elif self.mode not in ["prefill", "decode"]:
-            raise ValueError(
-                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
-            )
-
-
-def register_disaggregation_server(
-    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
-):
-    boostrap_port = bootstrap_port if mode == "prefill" else None
-    registry_request = PDRegistryRequest(
-        mode=mode,
-        registry_url=f"http://{get_ip()}:{server_port}",
-        bootstrap_port=boostrap_port,
-    )
-    res = requests.post(
-        f"{pdlb_url}/register",
-        json=dataclasses.asdict(registry_request),
-    )
-    if res.status_code != 200:
-        warnings.warn(
-            f"Failed to register disaggregation server: {res.status_code} {res.text}"
-        )
-
-
 #########################
 # Misc
 #########################
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -64,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]
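The cast matters because ReduceOp.SUM is a pybind extension object; a plain int is a constant torch.compile can fold instead of tracing through. A quick illustrative check:

    import torch.distributed as dist

    # pybind enum -> plain Python int; the int constant-folds where the
    # opaque enum object would force a graph break under torch.compile.
    REDUCE_OP_SUM = int(dist.ReduceOp.SUM)
    print(type(REDUCE_OP_SUM))  # <class 'int'>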
@@ -489,9 +492,7 @@ class GroupCoordinator:

         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
         return input_
@@ -879,17 +880,16 @@ class GroupCoordinator:
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
-
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
-        )
+        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

         # Send object
         torch.distributed.send(
-            object_tensor, dst=self.ranks[dst], group=self.device_group
+            object_tensor,
+            dst=self.ranks[dst],
+            group=self.device_group,
         )

         return None
@@ -904,13 +904,11 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."

-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.device_group
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )

         # Tensor to receive serialized objects into.
@@ -928,7 +926,7 @@ class GroupCoordinator:
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."

-        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy())

         return obj

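Taken together, these hunks reshape the object send/recv protocol: the 8-byte length now travels over the CPU (gloo) group while the pickled payload stays on the device group. A hedged sketch of the send side (simplified; rank translation and validation omitted):

    import pickle

    import torch
    import torch.distributed as dist

    def send_object_sketch(obj, dst_rank, device_group, cpu_group):
        payload = torch.frombuffer(
            bytearray(pickle.dumps(obj)), dtype=torch.uint8
        )
        size = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
        dist.send(size, dst=dst_rank, group=cpu_group)  # small, CPU-only hop
        # payload must live on device_group's device (e.g. .cuda() for NCCL)
        dist.send(payload, dst=dst_rank, group=device_group)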
@@ -1461,43 +1459,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)

-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)

-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
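A worked example of the group math above, assuming one TP group of 8 ranks with expert_model_parallel_size=2 (so moe_tp_size=4). When moe_ep_size (or moe_tp_size) equals tensor_model_parallel_size the grid degenerates to the TP group itself, which is why the new fast paths can alias _TP:

    tensor_model_parallel_size, moe_ep_size = 8, 2
    moe_tp_size = tensor_model_parallel_size // moe_ep_size  # 4
    num_tensor_model_parallel_groups = 1

    ep_groups, tp_groups = [], []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))
        for j in range(moe_ep_size):
            st = i * tensor_model_parallel_size + j * moe_tp_size
            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))

    print(ep_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]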
@@ -1583,6 +1587,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
     _TP = old_tp_group


+def get_world_size():
+    """Return world size for the world group."""
+    return get_world_group().world_size
+
+
+def get_world_rank():
+    """Return my rank for the world group."""
+    return get_world_group().rank_in_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size
@@ -1593,6 +1607,16 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    return get_pp_group().world_size
+
+
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
 def get_moe_expert_parallel_world_size():
     """Return world size for the moe expert parallel group."""
     return get_moe_ep_group().world_size
sglang/srt/entrypoints/engine.py
CHANGED
@@ -33,6 +33,8 @@ import zmq
 import zmq.asyncio
 from PIL.Image import Image

+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -138,6 +140,12 @@ class Engine(EngineBase):
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )

+        if server_args.enable_trace:
+            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+            if server_args.disaggregation_mode == "null":
+                thread_label = "Tokenizer"
+                trace_set_thread_info(thread_label)
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -364,9 +372,9 @@ class Engine(EngineBase):
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(self.tokenizer_manager.flush_cache())

-    def start_profile(self):
+    def start_profile(self, **kwargs):
         loop = asyncio.get_event_loop()
-        loop.run_until_complete(self.tokenizer_manager.start_profile())
+        loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))

     def stop_profile(self):
         loop = asyncio.get_event_loop()
@@ -655,7 +663,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
     # flashinfer uses this environment variable for various kernels from MoE to quant kernels
-    os.environ["TRTLLM_ENABLE_PDL"] = "1"
+    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        os.environ["TRTLLM_ENABLE_PDL"] = "1"

     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
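The new guard keeps PDL enabled by default but honors an explicit opt-out; only "0" suppresses the export. A quick behavior sketch (plain dicts stand in for os.environ):

    for preset in (None, "1", "0"):
        env = {} if preset is None else {"TRTLLM_ENABLE_PDL": preset}
        if env.get("TRTLLM_ENABLE_PDL", "1") != "0":
            env["TRTLLM_ENABLE_PDL"] = "1"
        print(preset, "->", env["TRTLLM_ENABLE_PDL"])
    # None -> 1, '1' -> 1, '0' -> 0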
@@ -673,7 +682,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.3.
+            "0.3.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -681,7 +690,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.9.post2",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

@@ -703,6 +712,24 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


+def _init_tokenizer_manager(
+    server_args: ServerArgs, port_args: PortArgs
+) -> TokenizerManager:
+    # Launch tokenizer process
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+
+    return tokenizer_manager, template_manager
+
+
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
 ) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -815,23 +842,15 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()
+
+    # Init tokenizer manager first, as the bootstrap server is initialized here
     if server_args.tokenizer_worker_num > 1:
         # Launch multi-tokenizer router
         tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
-
-        # Initialize templates
         template_manager = None
     else:
-        # Launch tokenizer process
-        tokenizer_manager = TokenizerManager(server_args, port_args)
-
-        # Initialize templates
-        template_manager = TemplateManager()
-        template_manager.initialize_templates(
-            tokenizer_manager=tokenizer_manager,
-            model_path=server_args.model_path,
-            chat_template=server_args.chat_template,
-            completion_template=server_args.completion_template,
+        tokenizer_manager, template_manager = _init_tokenizer_manager(
+            server_args, port_args
         )

     # Wait for the model to finish loading
@@ -855,5 +874,7 @@ def _launch_subprocesses(

     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
+
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+
     return tokenizer_manager, template_manager, scheduler_info