sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A controller that dispatches requests to multiple data parallel workers."""
 
+import faulthandler
 import logging
 import multiprocessing as mp
 import signal
@@ -39,7 +40,12 @@ from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    bind_port,
+    configure_logger,
+    get_zmq_socket,
+    kill_itself_when_parent_died,
+)
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -100,7 +106,7 @@ class DataParallelController:
 
         # Launch data parallel workers
         self.scheduler_procs = []
-        self.workers = [None] * server_args.dp_size
+        self.workers: List[zmq.Socket] = [None] * server_args.dp_size
 
         if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
@@ -266,27 +272,34 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
+    def maybe_external_dp_rank_routing(self, req: Req):
+        if req.data_parallel_rank is not None:
+            logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+            self.workers[req.data_parallel_rank].send_pyobj(req)
+            return True
+        return False
+
     def round_robin_scheduler(self, req: Req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         if self.server_args.disaggregation_mode == "null":
-
-
-                    self.workers
-
-                self.workers[self.round_robin_counter].send_pyobj(req)
-                self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                    self.workers
-                )
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
         else:
-
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
 
     def shortest_queue_scheduler(self, input_requests):
+        if self.maybe_external_dp_rank_routing(req):
+            return
         raise NotImplementedError()
 
     def minimum_tokens_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
         # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
         def get_next_global_balance_id() -> int:
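The hunk above factors direct-to-DP-rank routing into a single helper that every scheduling policy checks first, instead of re-testing `req.data_parallel_rank` inside each branch. A minimal standalone sketch of the resulting dispatch behavior (the `FakeWorker` class stands in for a zmq socket and is not part of sglang; only the control flow mirrors the diff):

```python
# Sketch of the dispatch logic above; FakeWorker is a stand-in, not sglang code.
class FakeWorker:
    def __init__(self, name):
        self.name = name

    def send_pyobj(self, req):
        print(f"{self.name} <- {req}")

class MiniController:
    def __init__(self, n):
        self.workers = [FakeWorker(f"dp{i}") for i in range(n)]
        self.round_robin_counter = 0

    def maybe_external_dp_rank_routing(self, req):
        # If the request pins a DP rank, route it directly and stop.
        if req.get("data_parallel_rank") is not None:
            self.workers[req["data_parallel_rank"]].send_pyobj(req)
            return True
        return False

    def round_robin_scheduler(self, req):
        if self.maybe_external_dp_rank_routing(req):
            return
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

c = MiniController(2)
c.round_robin_scheduler({"data_parallel_rank": 1})     # pinned -> dp1
c.round_robin_scheduler({"data_parallel_rank": None})  # round robin -> dp0
```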
@@ -343,7 +356,9 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    kill_itself_when_parent_died()
     setproctitle.setproctitle("sglang::data_parallel_controller")
+    faulthandler.enable()
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
     balance_meta = DPBalanceMeta(server_args.dp_size)
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -32,7 +32,9 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     FreezeGCReq,
+    MultiTokenizerRegisterReq,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
@@ -67,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0
 
 
-class DetokenizerManager:
+class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""
 
     def __init__(
@@ -102,10 +104,13 @@ class DetokenizerManager:
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )
 
+        self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
+
     def event_loop(self):
         """The event loop that handles requests"""
         while True:
@@ -133,6 +138,9 @@ class DetokenizerManager:
 
         # Trim stop token.
         if isinstance(matched, int) and isinstance(output, list):
+            # 200012 <|call|> is the tool call token and one of eos tokens for gpt-oss model
+            if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
+                return output
            assert len(output) > 0
            return output[:-1]
        return output
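The stop-token trim now has one exception: for gpt-oss, the `<|call|>` token (id 200012) is both an EOS token and the marker the tool-call parser needs, so it is kept rather than trimmed. A minimal sketch of the rule as a free function (`trim_stop_token` is a hypothetical name for illustration; in the diff the logic lives inline in `DetokenizerManager`, and an emptiness guard is added here for safety):

```python
# Sketch of the stop-token trimming rule above; not the sglang API itself.
GPT_OSS_CALL_TOKEN = 200012  # <|call|>, also one of gpt-oss's EOS tokens

def trim_stop_token(output, matched, is_gpt_oss_tool_parser):
    """Drop the matched stop token, except keep <|call|> for gpt-oss so the
    downstream tool-call parser can still see it."""
    if isinstance(matched, int) and isinstance(output, list):
        if output and output[-1] == GPT_OSS_CALL_TOKEN and is_gpt_oss_tool_parser:
            return output
        assert len(output) > 0
        return output[:-1]
    return output

print(trim_stop_token([1, 2, 200012], 200012, True))   # [1, 2, 200012]
print(trim_stop_token([1, 2, 200012], 200012, False))  # [1, 2]
```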
@@ -238,6 +246,8 @@ class DetokenizerManager:
             output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
             output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
             output_hidden_states=recv_obj.output_hidden_states,
+            placeholder_tokens_idx=None,
+            placeholder_tokens_val=None,
         )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -249,6 +259,8 @@ class DetokenizerManager:
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
+            placeholder_tokens_idx=None,
+            placeholder_tokens_val=None,
         )
 
     def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
@@ -280,8 +292,12 @@ def run_detokenizer_process(
 
     try:
         manager = DetokenizerManager(server_args, port_args)
-        manager.event_loop()
+        if server_args.tokenizer_worker_num > 1:
+            manager.multi_http_worker_event_loop()
+        else:
+            manager.event_loop()
     except Exception:
+        manager.socket_mapping.clear_all_sockets()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/disagg_service.py
ADDED
@@ -0,0 +1,46 @@
+"""Start bootstrap/kv-store-related server"""
+
+import os
+from typing import Type
+
+from sglang.srt.disaggregation.base import BaseKVBootstrapServer
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
+    TransferBackend,
+    get_kv_class,
+)
+from sglang.srt.server_args import ServerArgs
+
+
+def start_disagg_service(
+    server_args: ServerArgs,
+):
+    # Start kv boostrap server on prefill
+    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
+    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
+
+    if disagg_mode == DisaggregationMode.PREFILL:
+        # only start bootstrap server on prefill tm
+        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
+            transfer_backend, KVClassType.BOOTSTRAP_SERVER
+        )
+        bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
+            host=server_args.host,
+            port=server_args.disaggregation_bootstrap_port,
+        )
+        is_create_store = (
+            server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
+        )
+        if is_create_store:
+            try:
+                from mf_adapter import create_config_store
+
+                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                create_config_store(ascend_url)
+            except Exception as e:
+                error_message = f"Failed create mf store, invalid ascend_url."
+                error_message += f" With exception {e}"
+                raise error_message
+
+        return bootstrap_server
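For orientation, a sketch of how a prefill-side caller might use this new helper. The `ServerArgs` field values shown are illustrative examples, not defaults:

```python
# Illustrative only: invoking the new disagg_service helper on a prefill node.
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",   # example model path
    disaggregation_mode="prefill",                   # "null" | "prefill" | "decode"
    disaggregation_transfer_backend="mooncake",      # example transfer backend
)

# Returns a BaseKVBootstrapServer on prefill nodes; implicitly None otherwise.
bootstrap_server = start_disagg_service(server_args)
```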
sglang/srt/managers/io_struct.py
CHANGED
@@ -121,6 +121,7 @@ class GenerateReqInput:
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None
+    bootstrap_pair_key: Optional[Union[List[str], str]] = None
 
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
@@ -128,6 +129,18 @@ class GenerateReqInput:
     # For background responses (OpenAI responses API)
     background: bool = False
 
+    # Conversation id used for tracking requests
+    conversation_id: Optional[str] = None
+
+    # Label for the request
+    label: Optional[str] = None
+
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # Image gen grpc migration
+    return_bytes: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -258,6 +271,7 @@ class GenerateReqInput:
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)
+        self._normalize_bootstrap_params(num)
 
     def _expand_inputs(self, num):
         """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
@@ -297,6 +311,11 @@ class GenerateReqInput:
             self.image_data = [[self.image_data]] * num
             self.modalities = ["image"] * num
         elif isinstance(self.image_data, list):
+            # Handle empty list case - treat as no images
+            if len(self.image_data) == 0:
+                self.image_data = [None] * num
+                return
+
             if len(self.image_data) != self.batch_size:
                 raise ValueError(
                     "The length of image_data should be equal to the batch size."
@@ -421,6 +440,40 @@ class GenerateReqInput:
                 "Cannot use list custom_logit_processor with parallel_sample_num > 1"
             )
 
+    def _normalize_bootstrap_params(self, num):
+        """Normalize bootstrap parameters for batch processing."""
+        # Normalize bootstrap_host
+        if self.bootstrap_host is None:
+            self.bootstrap_host = [None] * num
+        elif not isinstance(self.bootstrap_host, list):
+            self.bootstrap_host = [self.bootstrap_host] * num
+        elif isinstance(self.bootstrap_host, list):
+            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num
+
+        # Normalize bootstrap_port
+        if self.bootstrap_port is None:
+            self.bootstrap_port = [None] * num
+        elif not isinstance(self.bootstrap_port, list):
+            self.bootstrap_port = [self.bootstrap_port] * num
+        elif isinstance(self.bootstrap_port, list):
+            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num
+
+        # Normalize bootstrap_room
+        if self.bootstrap_room is None:
+            self.bootstrap_room = [None] * num
+        elif not isinstance(self.bootstrap_room, list):
+            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
+        elif isinstance(self.bootstrap_room, list):
+            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
+
+        # Normalize bootstrap_pair_key
+        if self.bootstrap_pair_key is None:
+            self.bootstrap_pair_key = [None] * num
+        elif not isinstance(self.bootstrap_pair_key, list):
+            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
+        elif isinstance(self.bootstrap_pair_key, list):
+            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
+
     def _validate_session_params(self):
         """Validate that session parameters are properly formatted."""
         if self.session_params is not None:
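The new `_normalize_bootstrap_params` follows the same expand-or-repeat convention as the other `_normalize_*` helpers, with one asymmetry: a scalar `bootstrap_room` is fanned out as `room + i` so each parallel sample gets a distinct room id. A standalone sketch of that behavior (merging `num` and `parallel_sample_num` into a single parameter for brevity; in the real method the list branch multiplies by `self.parallel_sample_num`):

```python
# Sketch of the bootstrap_room normalization above, outside the dataclass.
def normalize_room(bootstrap_room, num):
    if bootstrap_room is None:
        return [None] * num
    if not isinstance(bootstrap_room, list):
        # Each parallel sample derives its own room id from the scalar.
        return [bootstrap_room + i for i in range(num)]
    # An existing list is repeated once per parallel sample.
    return bootstrap_room * num  # num plays the role of parallel_sample_num here

print(normalize_room(1000, 3))   # [1000, 1001, 1002]
print(normalize_room([7, 8], 2)) # [7, 8, 7, 8]
print(normalize_room(None, 2))   # [None, None]
```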
@@ -453,7 +506,13 @@ class GenerateReqInput:
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
             log_metrics=self.log_metrics,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             modalities=self.modalities[i] if self.modalities else None,
+            session_params=self.session_params,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
             lora_id=self.lora_id[i] if self.lora_id is not None else None,
             custom_logit_processor=(
@@ -461,11 +520,6 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=(
-                self.return_hidden_states[i]
-                if isinstance(self.return_hidden_states, list)
-                else self.return_hidden_states
-            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -476,9 +530,18 @@ class GenerateReqInput:
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
+            bootstrap_pair_key=(
+                self.bootstrap_pair_key[i]
+                if self.bootstrap_pair_key is not None
+                else None
+            ),
             data_parallel_rank=(
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
+            conversation_id=self.conversation_id,
+            label=self.label,
+            priority=self.priority,
+            return_bytes=self.return_bytes,
         )
 
 
@@ -504,27 +567,28 @@ class TokenizedGenerateReqInput:
     token_ids_logprob: List[int]
     # Whether to stream output
     stream: bool
+    # Whether to return hidden states
+    return_hidden_states: bool = False
 
-    # LoRA related
-    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # LoRA related
+    lora_id: Optional[str] = None  # None means just use the base model
+
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[str] = None
 
-    # Whether to return hidden states
-    return_hidden_states: bool = False
-
     # For disaggregated inference
     bootstrap_host: Optional[str] = None
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
+    bootstrap_pair_key: Optional[str] = None
 
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
@@ -532,6 +596,30 @@ class TokenizedGenerateReqInput:
     # For dp balance
     dp_balance_id: int = -1
 
+    # Label for the request
+    label: Optional[str] = None
+
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # Image gen grpc migration
+    return_bytes: bool = False
+
+
+@dataclass
+class BatchTokenizedGenerateReqInput:
+    # The batch of tokenized requests
+    batch: List[TokenizedGenerateReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
 
 @dataclass
 class EmbeddingReqInput:
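`BatchTokenizedGenerateReqInput` (and the embedding variant in the next hunk) is a thin sequence adapter: by implementing `__len__`, `__getitem__`, and `__iter__`, it lets downstream code treat a fused batch like a plain list of requests. A sketch of the pattern with a generic stand-in element type (`Item` here is hypothetical, replacing `TokenizedGenerateReqInput`):

```python
# Sketch of the sequence-adapter pattern used by the new Batch* wrappers.
from dataclasses import dataclass
from typing import List

@dataclass
class Item:            # stand-in for TokenizedGenerateReqInput
    rid: str

@dataclass
class BatchOfItems:    # mirrors BatchTokenizedGenerateReqInput
    batch: List[Item]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)

reqs = BatchOfItems([Item("a"), Item("b")])
assert len(reqs) == 2 and reqs[0].rid == "a"
assert [r.rid for r in reqs] == ["a", "b"]  # len/index/iterate like a list
```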
@@ -668,6 +756,21 @@ class TokenizedEmbeddingReqInput:
     dp_balance_id: int = -1
 
 
+@dataclass
+class BatchTokenizedEmbeddingReqInput:
+    # The batch of tokenized embedding requests
+    batch: List[TokenizedEmbeddingReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
@@ -708,9 +811,26 @@ class BatchTokenIDOut:
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
 
 @dataclass
 class BatchMultimodalDecodeReq:
+    decoded_ids: List[int]
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    read_offsets: List[int]
+    skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
+    image_resolutions: List[List[int]]
+    resize_image_resolutions: List[List[int]]
+
     # The request id
     rids: List[str]
     finished_reasons: List[BaseFinishReason]
@@ -720,6 +840,12 @@ class BatchMultimodalDecodeReq:
     completion_tokens: List[int]
     cached_tokens: List[int]
 
+    # Placeholder token info
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
+    return_bytes: bool = False
+
 
 @dataclass
 class BatchStrOut:
@@ -755,6 +881,9 @@ class BatchStrOut:
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
 
 @dataclass
 class BatchMultimodalOut:
@@ -762,14 +891,26 @@ class BatchMultimodalOut:
     rids: List[str]
     # The finish reason
     finished_reasons: List[dict]
+    decoded_ids: List[List[int]]
     # The outputs
-    outputs: List[List[Dict]]
+    outputs: Union[List[str | bytes], List[List[Dict]]]
+
+    # probability values for input tokens and output tokens
+    input_token_logprobs_val: List[List[float]]
+    input_token_logprobs_idx: List[List[int]]
+    output_token_logprobs_val: List[List[float]]
+    output_token_logprobs_idx: List[List[int]]
 
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
 
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
+    return_bytes: List[bool]
+
 
 @dataclass
 class BatchEmbeddingOut:
@@ -782,6 +923,19 @@ class BatchEmbeddingOut:
     # Token counts
     prompt_tokens: List[int]
     cached_tokens: List[int]
+    # Placeholder token info
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
+
+@dataclass
+class ClearHiCacheReqInput:
+    pass
+
+
+@dataclass
+class ClearHiCacheReqOutput:
+    success: bool
 
 
 @dataclass
@@ -804,6 +958,12 @@ class UpdateWeightFromDiskReqInput:
     abort_all_requests: bool = False
     # Optional: Update weight version along with weights
     weight_version: Optional[str] = None
+    # Whether to update weights asynchronously
+    is_async: bool = False
+    # Whether to empty torch cache
+    torch_empty_cache: bool = False
+    # Whether to keep the scheduler paused after weight update
+    keep_pause: bool = False
 
 
 @dataclass
@@ -943,6 +1103,12 @@ class AbortReq:
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
+    abort_reason: Optional[str] = None
+    # used in MultiTokenzierManager mode
+    rids: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        self.rids = self.rid
 
 
 @dataclass
@@ -1016,6 +1182,7 @@ class ConfigureLoggingReq:
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None
+    crash_dump_folder: Optional[str] = None
 
 
 @dataclass
@@ -1143,6 +1310,18 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
 
 
+@dataclass
+class MultiTokenizerRegisterReq:
+    rids: Optional[Union[List[str], str]] = None
+    ipc_name: Optional[str] = None
+
+
+@dataclass
+class MultiTokenizerWrapper:
+    worker_id: int
+    obj: Optional[Any] = None
+
+
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
sglang/srt/managers/mm_utils.py
CHANGED
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
@@ -625,6 +629,7 @@ def general_mm_embed_routine(
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
        and not forward_batch.forward_mode.is_target_verify()
         and forward_batch.contains_mm_inputs()
     ):
         mm_inputs_list = [
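The last hunk narrows the guard in `general_mm_embed_routine`: batches in target-verify mode (speculative decoding) now skip the multimodal-embedding branch, just like decode batches. A condensed sketch of the widened condition (the `Mode` class is a stand-in for sglang's `ForwardMode`, modeling only the predicates used in the diff):

```python
# Sketch of the widened guard; Mode is a stand-in, not sglang's ForwardMode.
class Mode:
    def __init__(self, kind):
        self.kind = kind

    def is_decode(self):
        return self.kind == "decode"

    def is_target_verify(self):
        return self.kind == "target_verify"

def takes_mm_embedding_path(mode, contains_mm_inputs):
    # decode and (newly) target_verify batches fall through to the
    # plain text-embedding path instead of the multimodal branch.
    return (
        not mode.is_decode()
        and not mode.is_target_verify()
        and contains_mm_inputs
    )

assert takes_mm_embedding_path(Mode("extend"), True) is True
assert takes_mm_embedding_path(Mode("target_verify"), True) is False
assert takes_mm_embedding_path(Mode("decode"), True) is False
```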
|