sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
15
15
|
|
16
|
+
import faulthandler
|
16
17
|
import logging
|
17
18
|
import multiprocessing as mp
|
18
19
|
import signal
|
@@ -39,7 +40,12 @@ from sglang.srt.managers.scheduler import run_scheduler_process
|
|
39
40
|
from sglang.srt.managers.utils import DPBalanceMeta
|
40
41
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
41
42
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
42
|
-
from sglang.srt.utils import
|
43
|
+
from sglang.srt.utils import (
|
44
|
+
bind_port,
|
45
|
+
configure_logger,
|
46
|
+
get_zmq_socket,
|
47
|
+
kill_itself_when_parent_died,
|
48
|
+
)
|
43
49
|
from sglang.utils import get_exception_traceback
|
44
50
|
|
45
51
|
logger = logging.getLogger(__name__)
|
@@ -100,7 +106,7 @@ class DataParallelController:
|
|
100
106
|
|
101
107
|
# Launch data parallel workers
|
102
108
|
self.scheduler_procs = []
|
103
|
-
self.workers = [None] * server_args.dp_size
|
109
|
+
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
104
110
|
|
105
111
|
if server_args.enable_dp_attention:
|
106
112
|
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
@@ -266,27 +272,34 @@ class DataParallelController:
|
|
266
272
|
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
267
273
|
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
268
274
|
|
275
|
+
def maybe_external_dp_rank_routing(self, req: Req):
|
276
|
+
if req.data_parallel_rank is not None:
|
277
|
+
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
278
|
+
self.workers[req.data_parallel_rank].send_pyobj(req)
|
279
|
+
return True
|
280
|
+
return False
|
281
|
+
|
269
282
|
def round_robin_scheduler(self, req: Req):
|
283
|
+
if self.maybe_external_dp_rank_routing(req):
|
284
|
+
return
|
285
|
+
|
270
286
|
if self.server_args.disaggregation_mode == "null":
|
271
|
-
|
272
|
-
|
273
|
-
self.workers
|
274
|
-
|
275
|
-
self.workers[self.round_robin_counter].send_pyobj(req)
|
276
|
-
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
277
|
-
self.workers
|
278
|
-
)
|
287
|
+
self.workers[self.round_robin_counter].send_pyobj(req)
|
288
|
+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
289
|
+
self.workers
|
290
|
+
)
|
279
291
|
else:
|
280
|
-
|
281
|
-
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
282
|
-
self.workers[req.data_parallel_rank].send_pyobj(req)
|
283
|
-
else:
|
284
|
-
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
292
|
+
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
285
293
|
|
286
294
|
def shortest_queue_scheduler(self, input_requests):
    """Dispatch a request to the worker with the shortest queue (not implemented).

    An explicitly requested data-parallel rank is honored first, the same way
    the round-robin and minimum-tokens policies do it.

    Args:
        input_requests: The incoming request object handed to the dispatcher.

    Raises:
        NotImplementedError: Whenever the request is not routed directly,
            because the shortest-queue policy itself has no implementation.
    """
    # Use the method's own parameter here; there is no ``req`` in this scope.
    if self.maybe_external_dp_rank_routing(input_requests):
        return
    raise NotImplementedError()
|
288
298
|
|
289
299
|
def minimum_tokens_scheduler(self, req):
|
300
|
+
if self.maybe_external_dp_rank_routing(req):
|
301
|
+
return
|
302
|
+
|
290
303
|
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
291
304
|
# We use it to control the number of in-flight tokens (requests dispatched to workers but not yet received).
|
292
305
|
def get_next_global_balance_id() -> int:
|
@@ -343,7 +356,9 @@ def run_data_parallel_controller_process(
|
|
343
356
|
port_args: PortArgs,
|
344
357
|
pipe_writer,
|
345
358
|
):
|
359
|
+
kill_itself_when_parent_died()
|
346
360
|
setproctitle.setproctitle("sglang::data_parallel_controller")
|
361
|
+
faulthandler.enable()
|
347
362
|
configure_logger(server_args)
|
348
363
|
parent_process = psutil.Process().parent()
|
349
364
|
balance_meta = DPBalanceMeta(server_args.dp_size)
|
@@ -32,7 +32,9 @@ from sglang.srt.managers.io_struct import (
|
|
32
32
|
BatchStrOut,
|
33
33
|
BatchTokenIDOut,
|
34
34
|
FreezeGCReq,
|
35
|
+
MultiTokenizerRegisterReq,
|
35
36
|
)
|
37
|
+
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
|
36
38
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
37
39
|
from sglang.srt.utils import (
|
38
40
|
configure_logger,
|
@@ -67,7 +69,7 @@ class DecodeStatus:
|
|
67
69
|
sent_offset: int = 0
|
68
70
|
|
69
71
|
|
70
|
-
class DetokenizerManager:
|
72
|
+
class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
71
73
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
72
74
|
|
73
75
|
def __init__(
|
@@ -102,6 +104,7 @@ class DetokenizerManager:
|
|
102
104
|
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
103
105
|
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
104
106
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
107
|
+
(MultiTokenizerRegisterReq, lambda x: x),
|
105
108
|
(FreezeGCReq, self.handle_freeze_gc_req),
|
106
109
|
]
|
107
110
|
)
|
@@ -243,6 +246,8 @@ class DetokenizerManager:
|
|
243
246
|
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
244
247
|
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
245
248
|
output_hidden_states=recv_obj.output_hidden_states,
|
249
|
+
placeholder_tokens_idx=None,
|
250
|
+
placeholder_tokens_val=None,
|
246
251
|
)
|
247
252
|
|
248
253
|
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
@@ -254,6 +259,8 @@ class DetokenizerManager:
|
|
254
259
|
prompt_tokens=recv_obj.prompt_tokens,
|
255
260
|
completion_tokens=recv_obj.completion_tokens,
|
256
261
|
cached_tokens=recv_obj.cached_tokens,
|
262
|
+
placeholder_tokens_idx=None,
|
263
|
+
placeholder_tokens_val=None,
|
257
264
|
)
|
258
265
|
|
259
266
|
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
|
@@ -285,8 +292,12 @@ def run_detokenizer_process(
|
|
285
292
|
|
286
293
|
try:
|
287
294
|
manager = DetokenizerManager(server_args, port_args)
|
288
|
-
|
295
|
+
if server_args.tokenizer_worker_num > 1:
|
296
|
+
manager.multi_http_worker_event_loop()
|
297
|
+
else:
|
298
|
+
manager.event_loop()
|
289
299
|
except Exception:
|
300
|
+
manager.socket_mapping.clear_all_sockets()
|
290
301
|
traceback = get_exception_traceback()
|
291
302
|
logger.error(f"DetokenizerManager hit an exception: {traceback}")
|
292
303
|
parent_process.send_signal(signal.SIGQUIT)
|
@@ -0,0 +1,46 @@
|
|
1
|
+
"""Start bootstrap/kv-store-related server"""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import Type
|
5
|
+
|
6
|
+
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
|
7
|
+
from sglang.srt.disaggregation.utils import (
|
8
|
+
DisaggregationMode,
|
9
|
+
KVClassType,
|
10
|
+
TransferBackend,
|
11
|
+
get_kv_class,
|
12
|
+
)
|
13
|
+
from sglang.srt.server_args import ServerArgs
|
14
|
+
|
15
|
+
|
16
|
+
def start_disagg_service(
    server_args: ServerArgs,
):
    """Start the bootstrap / KV-store services used for disaggregated serving.

    A bootstrap server is created only when ``server_args.disaggregation_mode``
    is the prefill mode. If, in addition, the transfer backend is Ascend and
    this is node rank 0, the ``mf_adapter`` config store is created from the
    ``ASCEND_MF_STORE_URL`` environment variable.

    Args:
        server_args: Server configuration. This function reads
            ``disaggregation_mode``, ``disaggregation_transfer_backend``,
            ``host``, ``disaggregation_bootstrap_port`` and ``node_rank``.

    Returns:
        The bootstrap server instance in prefill mode, otherwise ``None``.

    Raises:
        RuntimeError: If importing ``mf_adapter`` or creating the config store
            fails; the original exception is chained as the cause.
    """
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)

    # Only the prefill side starts a KV bootstrap server.
    if disagg_mode != DisaggregationMode.PREFILL:
        return None

    kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
        transfer_backend, KVClassType.BOOTSTRAP_SERVER
    )
    bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
        host=server_args.host,
        port=server_args.disaggregation_bootstrap_port,
    )

    # The Ascend config store is created once, by node rank 0.
    is_create_store = (
        server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
    )
    if is_create_store:
        try:
            # Imported lazily so the dependency is only needed on Ascend.
            from mf_adapter import create_config_store

            ascend_url = os.getenv("ASCEND_MF_STORE_URL")
            create_config_store(ascend_url)
        except Exception as e:
            error_message = "Failed create mf store, invalid ascend_url."
            error_message += f" With exception {e}"
            # Raise a real exception type: raising a plain string is itself a
            # TypeError and would mask the underlying failure.
            raise RuntimeError(error_message) from e

    return bootstrap_server
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -121,6 +121,7 @@ class GenerateReqInput:
|
|
121
121
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
122
122
|
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
123
123
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
124
|
+
bootstrap_pair_key: Optional[Union[List[str], str]] = None
|
124
125
|
|
125
126
|
# For data parallel rank routing
|
126
127
|
data_parallel_rank: Optional[int] = None
|
@@ -128,6 +129,18 @@ class GenerateReqInput:
|
|
128
129
|
# For background responses (OpenAI responses API)
|
129
130
|
background: bool = False
|
130
131
|
|
132
|
+
# Conversation id used for tracking requests
|
133
|
+
conversation_id: Optional[str] = None
|
134
|
+
|
135
|
+
# Label for the request
|
136
|
+
label: Optional[str] = None
|
137
|
+
|
138
|
+
# Priority for the request
|
139
|
+
priority: Optional[int] = None
|
140
|
+
|
141
|
+
# Image gen grpc migration
|
142
|
+
return_bytes: bool = False
|
143
|
+
|
131
144
|
def contains_mm_input(self) -> bool:
|
132
145
|
return (
|
133
146
|
has_valid_data(self.image_data)
|
@@ -258,6 +271,7 @@ class GenerateReqInput:
|
|
258
271
|
self._normalize_sampling_params(num)
|
259
272
|
self._normalize_logprob_params(num)
|
260
273
|
self._normalize_custom_logit_processor(num)
|
274
|
+
self._normalize_bootstrap_params(num)
|
261
275
|
|
262
276
|
def _expand_inputs(self, num):
|
263
277
|
"""Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
|
@@ -297,6 +311,11 @@ class GenerateReqInput:
|
|
297
311
|
self.image_data = [[self.image_data]] * num
|
298
312
|
self.modalities = ["image"] * num
|
299
313
|
elif isinstance(self.image_data, list):
|
314
|
+
# Handle empty list case - treat as no images
|
315
|
+
if len(self.image_data) == 0:
|
316
|
+
self.image_data = [None] * num
|
317
|
+
return
|
318
|
+
|
300
319
|
if len(self.image_data) != self.batch_size:
|
301
320
|
raise ValueError(
|
302
321
|
"The length of image_data should be equal to the batch size."
|
@@ -421,6 +440,40 @@ class GenerateReqInput:
|
|
421
440
|
"Cannot use list custom_logit_processor with parallel_sample_num > 1"
|
422
441
|
)
|
423
442
|
|
443
|
+
def _normalize_bootstrap_params(self, num):
|
444
|
+
"""Normalize bootstrap parameters for batch processing."""
|
445
|
+
# Normalize bootstrap_host
|
446
|
+
if self.bootstrap_host is None:
|
447
|
+
self.bootstrap_host = [None] * num
|
448
|
+
elif not isinstance(self.bootstrap_host, list):
|
449
|
+
self.bootstrap_host = [self.bootstrap_host] * num
|
450
|
+
elif isinstance(self.bootstrap_host, list):
|
451
|
+
self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num
|
452
|
+
|
453
|
+
# Normalize bootstrap_port
|
454
|
+
if self.bootstrap_port is None:
|
455
|
+
self.bootstrap_port = [None] * num
|
456
|
+
elif not isinstance(self.bootstrap_port, list):
|
457
|
+
self.bootstrap_port = [self.bootstrap_port] * num
|
458
|
+
elif isinstance(self.bootstrap_port, list):
|
459
|
+
self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num
|
460
|
+
|
461
|
+
# Normalize bootstrap_room
|
462
|
+
if self.bootstrap_room is None:
|
463
|
+
self.bootstrap_room = [None] * num
|
464
|
+
elif not isinstance(self.bootstrap_room, list):
|
465
|
+
self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
|
466
|
+
elif isinstance(self.bootstrap_room, list):
|
467
|
+
self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
|
468
|
+
|
469
|
+
# Normalize bootstrap_pair_key
|
470
|
+
if self.bootstrap_pair_key is None:
|
471
|
+
self.bootstrap_pair_key = [None] * num
|
472
|
+
elif not isinstance(self.bootstrap_pair_key, list):
|
473
|
+
self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
|
474
|
+
elif isinstance(self.bootstrap_pair_key, list):
|
475
|
+
self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
|
476
|
+
|
424
477
|
def _validate_session_params(self):
|
425
478
|
"""Validate that session parameters are properly formatted."""
|
426
479
|
if self.session_params is not None:
|
@@ -453,7 +506,13 @@ class GenerateReqInput:
|
|
453
506
|
return_text_in_logprobs=self.return_text_in_logprobs,
|
454
507
|
stream=self.stream,
|
455
508
|
log_metrics=self.log_metrics,
|
509
|
+
return_hidden_states=(
|
510
|
+
self.return_hidden_states[i]
|
511
|
+
if isinstance(self.return_hidden_states, list)
|
512
|
+
else self.return_hidden_states
|
513
|
+
),
|
456
514
|
modalities=self.modalities[i] if self.modalities else None,
|
515
|
+
session_params=self.session_params,
|
457
516
|
lora_path=self.lora_path[i] if self.lora_path is not None else None,
|
458
517
|
lora_id=self.lora_id[i] if self.lora_id is not None else None,
|
459
518
|
custom_logit_processor=(
|
@@ -461,11 +520,6 @@ class GenerateReqInput:
|
|
461
520
|
if self.custom_logit_processor is not None
|
462
521
|
else None
|
463
522
|
),
|
464
|
-
return_hidden_states=(
|
465
|
-
self.return_hidden_states[i]
|
466
|
-
if isinstance(self.return_hidden_states, list)
|
467
|
-
else self.return_hidden_states
|
468
|
-
),
|
469
523
|
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
|
470
524
|
bootstrap_host=(
|
471
525
|
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
@@ -476,9 +530,18 @@ class GenerateReqInput:
|
|
476
530
|
bootstrap_room=(
|
477
531
|
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
478
532
|
),
|
533
|
+
bootstrap_pair_key=(
|
534
|
+
self.bootstrap_pair_key[i]
|
535
|
+
if self.bootstrap_pair_key is not None
|
536
|
+
else None
|
537
|
+
),
|
479
538
|
data_parallel_rank=(
|
480
539
|
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
481
540
|
),
|
541
|
+
conversation_id=self.conversation_id,
|
542
|
+
label=self.label,
|
543
|
+
priority=self.priority,
|
544
|
+
return_bytes=self.return_bytes,
|
482
545
|
)
|
483
546
|
|
484
547
|
|
@@ -504,27 +567,28 @@ class TokenizedGenerateReqInput:
|
|
504
567
|
token_ids_logprob: List[int]
|
505
568
|
# Whether to stream output
|
506
569
|
stream: bool
|
570
|
+
# Whether to return hidden states
|
571
|
+
return_hidden_states: bool = False
|
507
572
|
|
508
|
-
# LoRA related
|
509
|
-
lora_id: Optional[str] = None # None means just use the base model
|
510
573
|
# The input embeds
|
511
574
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
512
575
|
|
513
576
|
# Session info for continual prompting
|
514
577
|
session_params: Optional[SessionParams] = None
|
515
578
|
|
579
|
+
# LoRA related
|
580
|
+
lora_id: Optional[str] = None # None means just use the base model
|
581
|
+
|
516
582
|
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
517
583
|
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
518
584
|
# Use the processor's `to_str()` method to generate the serialized string.
|
519
585
|
custom_logit_processor: Optional[str] = None
|
520
586
|
|
521
|
-
# Whether to return hidden states
|
522
|
-
return_hidden_states: bool = False
|
523
|
-
|
524
587
|
# For disaggregated inference
|
525
588
|
bootstrap_host: Optional[str] = None
|
526
589
|
bootstrap_port: Optional[int] = None
|
527
590
|
bootstrap_room: Optional[int] = None
|
591
|
+
bootstrap_pair_key: Optional[str] = None
|
528
592
|
|
529
593
|
# For data parallel rank routing
|
530
594
|
data_parallel_rank: Optional[int] = None
|
@@ -532,6 +596,15 @@ class TokenizedGenerateReqInput:
|
|
532
596
|
# For dp balance
|
533
597
|
dp_balance_id: int = -1
|
534
598
|
|
599
|
+
# Label for the request
|
600
|
+
label: Optional[str] = None
|
601
|
+
|
602
|
+
# Priority for the request
|
603
|
+
priority: Optional[int] = None
|
604
|
+
|
605
|
+
# Image gen grpc migration
|
606
|
+
return_bytes: bool = False
|
607
|
+
|
535
608
|
|
536
609
|
@dataclass
|
537
610
|
class BatchTokenizedGenerateReqInput:
|
@@ -738,9 +811,26 @@ class BatchTokenIDOut:
|
|
738
811
|
# Hidden states
|
739
812
|
output_hidden_states: List[List[float]]
|
740
813
|
|
814
|
+
# The information of placeholder tokens (e.g., image token)
|
815
|
+
# idx is the index of the token in the prompt after expansion.
|
816
|
+
# val is the length of padded tokens after expansion.
|
817
|
+
placeholder_tokens_idx: List[Optional[List[int]]]
|
818
|
+
placeholder_tokens_val: List[Optional[List[int]]]
|
819
|
+
|
741
820
|
|
742
821
|
@dataclass
|
743
822
|
class BatchMultimodalDecodeReq:
|
823
|
+
decoded_ids: List[int]
|
824
|
+
input_token_logprobs_val: List[float]
|
825
|
+
input_token_logprobs_idx: List[int]
|
826
|
+
output_token_logprobs_val: List[float]
|
827
|
+
output_token_logprobs_idx: List[int]
|
828
|
+
read_offsets: List[int]
|
829
|
+
skip_special_tokens: List[bool]
|
830
|
+
spaces_between_special_tokens: List[bool]
|
831
|
+
image_resolutions: List[List[int]]
|
832
|
+
resize_image_resolutions: List[List[int]]
|
833
|
+
|
744
834
|
# The request id
|
745
835
|
rids: List[str]
|
746
836
|
finished_reasons: List[BaseFinishReason]
|
@@ -750,6 +840,12 @@ class BatchMultimodalDecodeReq:
|
|
750
840
|
completion_tokens: List[int]
|
751
841
|
cached_tokens: List[int]
|
752
842
|
|
843
|
+
# Placeholder token info
|
844
|
+
placeholder_tokens_idx: List[Optional[List[int]]]
|
845
|
+
placeholder_tokens_val: List[Optional[List[int]]]
|
846
|
+
|
847
|
+
return_bytes: bool = False
|
848
|
+
|
753
849
|
|
754
850
|
@dataclass
|
755
851
|
class BatchStrOut:
|
@@ -785,6 +881,9 @@ class BatchStrOut:
|
|
785
881
|
# Hidden states
|
786
882
|
output_hidden_states: List[List[float]]
|
787
883
|
|
884
|
+
placeholder_tokens_idx: List[Optional[List[int]]]
|
885
|
+
placeholder_tokens_val: List[Optional[List[int]]]
|
886
|
+
|
788
887
|
|
789
888
|
@dataclass
|
790
889
|
class BatchMultimodalOut:
|
@@ -792,14 +891,26 @@ class BatchMultimodalOut:
|
|
792
891
|
rids: List[str]
|
793
892
|
# The finish reason
|
794
893
|
finished_reasons: List[dict]
|
894
|
+
decoded_ids: List[List[int]]
|
795
895
|
# The outputs
|
796
|
-
outputs: List[List[Dict]]
|
896
|
+
outputs: Union[List[str | bytes], List[List[Dict]]]
|
897
|
+
|
898
|
+
# probability values for input tokens and output tokens
|
899
|
+
input_token_logprobs_val: List[List[float]]
|
900
|
+
input_token_logprobs_idx: List[List[int]]
|
901
|
+
output_token_logprobs_val: List[List[float]]
|
902
|
+
output_token_logprobs_idx: List[List[int]]
|
797
903
|
|
798
904
|
# Token counts
|
799
905
|
prompt_tokens: List[int]
|
800
906
|
completion_tokens: List[int]
|
801
907
|
cached_tokens: List[int]
|
802
908
|
|
909
|
+
placeholder_tokens_idx: List[Optional[List[int]]]
|
910
|
+
placeholder_tokens_val: List[Optional[List[int]]]
|
911
|
+
|
912
|
+
return_bytes: List[bool]
|
913
|
+
|
803
914
|
|
804
915
|
@dataclass
|
805
916
|
class BatchEmbeddingOut:
|
@@ -812,6 +923,19 @@ class BatchEmbeddingOut:
|
|
812
923
|
# Token counts
|
813
924
|
prompt_tokens: List[int]
|
814
925
|
cached_tokens: List[int]
|
926
|
+
# Placeholder token info
|
927
|
+
placeholder_tokens_idx: List[Optional[List[int]]]
|
928
|
+
placeholder_tokens_val: List[Optional[List[int]]]
|
929
|
+
|
930
|
+
|
931
|
+
@dataclass
|
932
|
+
class ClearHiCacheReqInput:
|
933
|
+
pass
|
934
|
+
|
935
|
+
|
936
|
+
@dataclass
|
937
|
+
class ClearHiCacheReqOutput:
|
938
|
+
success: bool
|
815
939
|
|
816
940
|
|
817
941
|
@dataclass
|
@@ -834,6 +958,12 @@ class UpdateWeightFromDiskReqInput:
|
|
834
958
|
abort_all_requests: bool = False
|
835
959
|
# Optional: Update weight version along with weights
|
836
960
|
weight_version: Optional[str] = None
|
961
|
+
# Whether to update weights asynchronously
|
962
|
+
is_async: bool = False
|
963
|
+
# Whether to empty torch cache
|
964
|
+
torch_empty_cache: bool = False
|
965
|
+
# Whether to keep the scheduler paused after weight update
|
966
|
+
keep_pause: bool = False
|
837
967
|
|
838
968
|
|
839
969
|
@dataclass
|
@@ -973,6 +1103,12 @@ class AbortReq:
|
|
973
1103
|
abort_all: bool = False
|
974
1104
|
# The finished reason data
|
975
1105
|
finished_reason: Optional[Dict[str, Any]] = None
|
1106
|
+
abort_reason: Optional[str] = None
|
1107
|
+
# used in MultiTokenzierManager mode
|
1108
|
+
rids: Optional[Union[List[str], str]] = None
|
1109
|
+
|
1110
|
+
def __post_init__(self):
|
1111
|
+
self.rids = self.rid
|
976
1112
|
|
977
1113
|
|
978
1114
|
@dataclass
|
@@ -1046,6 +1182,7 @@ class ConfigureLoggingReq:
|
|
1046
1182
|
log_requests_level: Optional[int] = None
|
1047
1183
|
dump_requests_folder: Optional[str] = None
|
1048
1184
|
dump_requests_threshold: Optional[int] = None
|
1185
|
+
crash_dump_folder: Optional[str] = None
|
1049
1186
|
|
1050
1187
|
|
1051
1188
|
@dataclass
|
@@ -1173,6 +1310,18 @@ class LoRAUpdateResult:
|
|
1173
1310
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
1174
1311
|
|
1175
1312
|
|
1313
|
+
@dataclass
|
1314
|
+
class MultiTokenizerRegisterReq:
|
1315
|
+
rids: Optional[Union[List[str], str]] = None
|
1316
|
+
ipc_name: Optional[str] = None
|
1317
|
+
|
1318
|
+
|
1319
|
+
@dataclass
|
1320
|
+
class MultiTokenizerWrapper:
|
1321
|
+
worker_id: int
|
1322
|
+
obj: Optional[Any] = None
|
1323
|
+
|
1324
|
+
|
1176
1325
|
class BlockReqType(Enum):
|
1177
1326
|
BLOCK = 1
|
1178
1327
|
UNBLOCK = 2
|
sglang/srt/managers/mm_utils.py
CHANGED
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
|
|
20
20
|
)
|
21
21
|
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
|
22
22
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
23
|
-
from sglang.srt.utils import flatten_nested_list, print_warning_once
|
23
|
+
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
|
24
24
|
from sglang.utils import logger
|
25
25
|
|
26
|
+
_is_npu = is_npu()
|
27
|
+
|
26
28
|
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
|
27
29
|
# to ensure consistent logging behavior across the codebase. This prevents issues with log
|
28
30
|
# propagation that can cause some log messages (like 'server is fired up') to not appear
|
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
|
|
486
488
|
if embedding is None:
|
487
489
|
return None, None
|
488
490
|
# 2. Get mask
|
491
|
+
if _is_npu:
|
492
|
+
torch.npu.current_stream().synchronize()
|
489
493
|
special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
|
490
494
|
# 3. Adjust embedding length if needed
|
491
495
|
embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
|
@@ -625,6 +629,7 @@ def general_mm_embed_routine(
|
|
625
629
|
embed_tokens = language_model.get_input_embeddings()
|
626
630
|
if (
|
627
631
|
not forward_batch.forward_mode.is_decode()
|
632
|
+
and not forward_batch.forward_mode.is_target_verify()
|
628
633
|
and forward_batch.contains_mm_inputs()
|
629
634
|
):
|
630
635
|
mm_inputs_list = [
|