sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A controller that dispatches requests to multiple data parallel workers."""
 
+import faulthandler
 import logging
 import multiprocessing as mp
 import signal
@@ -20,6 +21,7 @@ import struct
 import sys
 import threading
 import time
+from collections import deque
 from enum import Enum, auto
 from multiprocessing import shared_memory
 from typing import Dict, List
@@ -33,14 +35,20 @@ from sglang.srt.managers.io_struct import (
     BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
-from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import (
+    bind_port,
+    configure_logger,
+    get_zmq_socket,
+    kill_itself_when_parent_died,
+)
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -61,6 +69,42 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc
 
 
+class DPBudget:
+    def __init__(self):
+        # TODO: support minimum tokens method
+        self.budget_queue = deque()
+
+    def update_budget(self, load_update: WatchLoadUpdateReq):
+        """Update the budget queue.
+        Use num_reqs instead of num_waiting_reqs to balance decode running batch.
+        """
+        loads = load_update.loads
+        self.budget_queue.clear()
+
+        num_reqs = [load.num_reqs for load in loads]
+        if not num_reqs:
+            return
+
+        max_num_reqs = max(num_reqs)
+        if all(x == max_num_reqs for x in num_reqs):
+            return
+
+        while any(x != num_reqs[0] for x in num_reqs):
+            min_load = min(num_reqs)
+            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
+            second_min_load = min(x for x in num_reqs if x > min_load)
+            self.budget_queue.extend(
+                [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
+            )
+            for idx in min_indices:
+                num_reqs[idx] = second_min_load
+
+    def dispatch(self):
+        if self.budget_queue:
+            return self.budget_queue.popleft()
+        return None
+
+
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
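The `DPBudget` added above is a water-filling balancer: on each `WatchLoadUpdateReq` it repeatedly raises the least-loaded DP ranks to the next load level, queuing one dispatch slot per request needed to close the gap, until all ranks are even. A minimal standalone sketch of the same logic (the `Load` dataclass is a stand-in for sglang's `GetLoadReqOutput`):

    from collections import deque
    from dataclasses import dataclass
    from typing import Deque, List


    @dataclass
    class Load:  # stand-in for GetLoadReqOutput
        dp_rank: int
        num_reqs: int


    def build_budget(loads: List[Load]) -> Deque[int]:
        """Queue dispatch slots for under-loaded ranks (water-filling)."""
        budget: Deque[int] = deque()
        num_reqs = [load.num_reqs for load in loads]
        if not num_reqs or all(x == num_reqs[0] for x in num_reqs):
            return budget
        while any(x != num_reqs[0] for x in num_reqs):
            min_load = min(num_reqs)
            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
            second_min = min(x for x in num_reqs if x > min_load)
            # One slot per request that would lift each minimum rank to the next level.
            budget.extend(
                [loads[i].dp_rank for i in min_indices] * (second_min - min_load)
            )
            for i in min_indices:
                num_reqs[i] = second_min
        return budget


    # Ranks 0/1/2 hold 1, 3, and 3 running requests: rank 0 should receive
    # the next two dispatch slots before normal scheduling resumes.
    print(list(build_budget([Load(0, 1), Load(1, 3), Load(2, 3)])))  # [0, 0]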
@@ -98,9 +142,12 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
+        # Load balance budget
+        self.dp_budget = DPBudget()
+
         # Launch data parallel workers
         self.scheduler_procs = []
-        self.workers = [None] * server_args.dp_size
+        self.workers: List[zmq.Socket] = [None] * server_args.dp_size
 
         if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
@@ -121,6 +168,31 @@ class DataParallelController:
 
         self.max_req_input_len = None
 
+        self.init_dispatcher()
+
+    def send_to_all_workers(self, obj):
+        for worker in self.workers:
+            worker.send_pyobj(obj)
+
+    def send_control_message(self, obj):
+        # Send control messages to first worker of tp group
+        for worker in self.workers[:: self.control_message_step]:
+            worker.send_pyobj(obj)
+
+    def handle_load_update_req(self, obj):
+        self.dp_budget.update_budget(obj)
+
+    def init_dispatcher(self):
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.dispatching),
+                (TokenizedEmbeddingReqInput, self.dispatching),
+                (BlockReqInput, self.send_to_all_workers),
+                (WatchLoadUpdateReq, self.handle_load_update_req),
+            ]
+        )
+        self._request_dispatcher.add_fallback_fn(self.send_control_message)
+
     def launch_dp_schedulers(self, server_args, port_args):
         base_gpu_id = 0
 
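`init_dispatcher` replaces the old `isinstance` chain in the event loop with a `TypeBasedDispatcher` from `sglang.utils`: a handler is registered per request type, and anything unmatched falls through to `send_control_message`. A minimal re-implementation of the pattern (illustrative only; the real class lives in `sglang.utils`):

    from typing import Any, Callable, List, Optional, Tuple


    class TypeBasedDispatcher:
        def __init__(self, mapping: List[Tuple[type, Callable]]):
            self._mapping = mapping
            self._fallback_fn: Optional[Callable] = None

        def add_fallback_fn(self, fallback_fn: Callable):
            # Invoked when no registered type matches the object.
            self._fallback_fn = fallback_fn

        def __call__(self, obj: Any):
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            if self._fallback_fn is not None:
                return self._fallback_fn(obj)
            raise ValueError(f"No handler registered for {type(obj)!r}")


    # Usage mirroring init_dispatcher: known types go to their handler,
    # everything else is treated as a control message.
    dispatcher = TypeBasedDispatcher([(int, lambda x: f"int {x}"), (str, str.upper)])
    dispatcher.add_fallback_fn(lambda obj: f"control {obj!r}")
    print(dispatcher(3), dispatcher("hi"), dispatcher(1.5))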
@@ -266,27 +338,38 @@ class DataParallelController:
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
         self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
+    def maybe_external_dp_rank_routing(self, req: Req):
+        if req.data_parallel_rank is not None:
+            logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+            self.workers[req.data_parallel_rank].send_pyobj(req)
+            return True
+        return False
+
     def round_robin_scheduler(self, req: Req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         if self.server_args.disaggregation_mode == "null":
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[self.round_robin_counter].send_pyobj(req)
-                self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                    self.workers
-                )
+            self.workers[self.round_robin_counter].send_pyobj(req)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                self.workers
+            )
         else:
-            if req.data_parallel_rank is not None:
-                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
-                self.workers[req.data_parallel_rank].send_pyobj(req)
-            else:
-                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
-
-    def shortest_queue_scheduler(self, input_requests):
-        raise NotImplementedError()
+            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+
+    def shortest_queue_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+        target_worker = self.dp_budget.dispatch()
+        if target_worker is None:
+            self.round_robin_scheduler(req)
+        else:
+            self.workers[target_worker].send_pyobj(req)
 
     def minimum_tokens_scheduler(self, req):
+        if self.maybe_external_dp_rank_routing(req):
+            return
+
         # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
         # We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
         def get_next_global_balance_id() -> int:
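Note that under disaggregation, `round_robin_scheduler` routes by `req.bootstrap_room % len(self.workers)` rather than a rotating counter, so the chosen DP rank is a pure function of the request's bootstrap room (presumably so that requests sharing a room land deterministically, independent of arrival order). A toy illustration:

    # Deterministic bootstrap_room routing vs. a rotating counter.
    def pick_worker_by_room(bootstrap_room: int, num_workers: int) -> int:
        # The same room always maps to the same DP rank.
        return bootstrap_room % num_workers


    print([pick_worker_by_room(room, 4) for room in (100, 101, 102, 100)])
    # [0, 1, 2, 0] - the repeated room 100 hits rank 0 again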
@@ -320,22 +403,7 @@ class DataParallelController:
                     recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                 except zmq.ZMQError:
                     break
-
-                if isinstance(
-                    recv_req,
-                    (
-                        TokenizedGenerateReqInput,
-                        TokenizedEmbeddingReqInput,
-                    ),
-                ):
-                    self.dispatching(recv_req)
-                elif isinstance(recv_req, BlockReqInput):
-                    for worker in self.workers:
-                        worker.send_pyobj(recv_req)
-                else:
-                    # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.control_message_step]:
-                        worker.send_pyobj(recv_req)
+                self._request_dispatcher(recv_req)
 
 
 def run_data_parallel_controller_process(
@@ -343,7 +411,9 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    kill_itself_when_parent_died()
     setproctitle.setproctitle("sglang::data_parallel_controller")
+    faulthandler.enable()
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
     balance_meta = DPBalanceMeta(server_args.dp_size)
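The controller process now calls `kill_itself_when_parent_died()` before doing anything else and enables `faulthandler`, so an orphaned controller dies with its parent and crashes print a Python traceback. On Linux this kind of helper is typically a `prctl(PR_SET_PDEATHSIG)` call; a sketch of the usual implementation (the real helper lives in `sglang.srt.utils` and may differ):

    import ctypes
    import signal
    import sys


    def kill_itself_when_parent_died():
        # Linux-only: ask the kernel to deliver SIGKILL to this process
        # when its parent exits, so workers never outlive the server.
        if sys.platform == "linux":
            PR_SET_PDEATHSIG = 1
            libc = ctypes.CDLL("libc.so.6")
            libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)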
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
-from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
+from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
@@ -69,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0
 
 
-class DetokenizerManager(MultiTokenizerMixin):
+class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""
 
     def __init__(
@@ -246,6 +246,8 @@ class DetokenizerManager(MultiTokenizerMixin):
             output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
             output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
             output_hidden_states=recv_obj.output_hidden_states,
+            placeholder_tokens_idx=None,
+            placeholder_tokens_val=None,
         )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -257,6 +259,8 @@ class DetokenizerManager(MultiTokenizerMixin):
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
+            placeholder_tokens_idx=None,
+            placeholder_tokens_val=None,
         )
 
     def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
@@ -289,11 +293,11 @@ def run_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
         if server_args.tokenizer_worker_num > 1:
-            manager.
+            manager.multi_http_worker_event_loop()
         else:
             manager.event_loop()
     except Exception:
-        manager.
+        manager.maybe_clear_socket_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/disagg_service.py
ADDED
@@ -0,0 +1,46 @@
+"""Start bootstrap/kv-store-related server"""
+
+import os
+from typing import Type
+
+from sglang.srt.disaggregation.base import BaseKVBootstrapServer
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
+    TransferBackend,
+    get_kv_class,
+)
+from sglang.srt.server_args import ServerArgs
+
+
+def start_disagg_service(
+    server_args: ServerArgs,
+):
+    # Start kv bootstrap server on prefill
+    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
+    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
+
+    if disagg_mode == DisaggregationMode.PREFILL:
+        # only start bootstrap server on prefill tm
+        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
+            transfer_backend, KVClassType.BOOTSTRAP_SERVER
+        )
+        bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
+            host=server_args.host,
+            port=server_args.disaggregation_bootstrap_port,
+        )
+        is_create_store = (
+            server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
+        )
+        if is_create_store:
+            try:
+                from mf_adapter import create_config_store
+
+                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                create_config_store(ascend_url)
+            except Exception as e:
+                error_message = f"Failed create mf store, invalid ascend_url."
+                error_message += f" With exception {e}"
+                raise error_message
+
+        return bootstrap_server
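`start_disagg_service` resolves the concrete bootstrap-server class through `get_kv_class(transfer_backend, KVClassType.BOOTSTRAP_SERVER)`, i.e. a (backend, class-kind) registry lookup. A minimal sketch of that factory pattern (the backend members and server classes below are illustrative, not sglang's actual registry):

    from enum import Enum, auto


    class TransferBackend(Enum):
        MOONCAKE = auto()
        NIXL = auto()


    class KVClassType(Enum):
        BOOTSTRAP_SERVER = auto()


    class MooncakeBootstrapServer:
        def __init__(self, host: str, port: int):
            self.addr = (host, port)


    class NixlBootstrapServer(MooncakeBootstrapServer):
        pass


    _KV_REGISTRY = {
        (TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER): MooncakeBootstrapServer,
        (TransferBackend.NIXL, KVClassType.BOOTSTRAP_SERVER): NixlBootstrapServer,
    }


    def get_kv_class(backend: TransferBackend, kind: KVClassType) -> type:
        # Look up the concrete class for this backend/kind pair.
        return _KV_REGISTRY[(backend, kind)]


    server_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER)
    print(server_cls("0.0.0.0", 8998).addr)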
sglang/srt/managers/io_struct.py
CHANGED
@@ -121,6 +121,7 @@ class GenerateReqInput:
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None
+    bootstrap_pair_key: Optional[Union[List[str], str]] = None
 
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
@@ -128,6 +129,21 @@ class GenerateReqInput:
     # For background responses (OpenAI responses API)
     background: bool = False
 
+    # Conversation id used for tracking requests
+    conversation_id: Optional[str] = None
+
+    # Label for the request
+    label: Optional[str] = None
+
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # Image gen grpc migration
+    return_bytes: bool = False
+
+    # For customer metric labels
+    customer_labels: Optional[Dict[str, str]] = None
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -258,6 +274,7 @@ class GenerateReqInput:
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)
+        self._normalize_bootstrap_params(num)
 
     def _expand_inputs(self, num):
         """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling."""
@@ -297,6 +314,11 @@ class GenerateReqInput:
             self.image_data = [[self.image_data]] * num
             self.modalities = ["image"] * num
         elif isinstance(self.image_data, list):
+            # Handle empty list case - treat as no images
+            if len(self.image_data) == 0:
+                self.image_data = [None] * num
+                return
+
             if len(self.image_data) != self.batch_size:
                 raise ValueError(
                     "The length of image_data should be equal to the batch size."
@@ -421,6 +443,40 @@ class GenerateReqInput:
                 "Cannot use list custom_logit_processor with parallel_sample_num > 1"
             )
 
+    def _normalize_bootstrap_params(self, num):
+        """Normalize bootstrap parameters for batch processing."""
+        # Normalize bootstrap_host
+        if self.bootstrap_host is None:
+            self.bootstrap_host = [None] * num
+        elif not isinstance(self.bootstrap_host, list):
+            self.bootstrap_host = [self.bootstrap_host] * num
+        elif isinstance(self.bootstrap_host, list):
+            self.bootstrap_host = self.bootstrap_host * self.parallel_sample_num
+
+        # Normalize bootstrap_port
+        if self.bootstrap_port is None:
+            self.bootstrap_port = [None] * num
+        elif not isinstance(self.bootstrap_port, list):
+            self.bootstrap_port = [self.bootstrap_port] * num
+        elif isinstance(self.bootstrap_port, list):
+            self.bootstrap_port = self.bootstrap_port * self.parallel_sample_num
+
+        # Normalize bootstrap_room
+        if self.bootstrap_room is None:
+            self.bootstrap_room = [None] * num
+        elif not isinstance(self.bootstrap_room, list):
+            self.bootstrap_room = [self.bootstrap_room + i for i in range(num)]
+        elif isinstance(self.bootstrap_room, list):
+            self.bootstrap_room = self.bootstrap_room * self.parallel_sample_num
+
+        # Normalize bootstrap_pair_key
+        if self.bootstrap_pair_key is None:
+            self.bootstrap_pair_key = [None] * num
+        elif not isinstance(self.bootstrap_pair_key, list):
+            self.bootstrap_pair_key = [self.bootstrap_pair_key] * num
+        elif isinstance(self.bootstrap_pair_key, list):
+            self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num
+
     def _validate_session_params(self):
         """Validate that session parameters are properly formatted."""
         if self.session_params is not None:
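Note the asymmetry in `_normalize_bootstrap_params`: a scalar `bootstrap_room` fans out to consecutive room ids (`room + i`), while scalar host/port/pair-key values are simply repeated, and list inputs are tiled once per parallel sample. A standalone sketch of the room rule:

    from typing import List, Optional, Union


    def normalize_room(
        room: Union[List[int], int, None], num: int, parallel_sample_num: int
    ) -> List[Optional[int]]:
        if room is None:
            return [None] * num
        if not isinstance(room, list):
            # Scalar: fan out to consecutive rooms so each sample gets a unique id.
            return [room + i for i in range(num)]
        # List: tile the per-request rooms once per parallel sample.
        return room * parallel_sample_num


    print(normalize_room(1000, 3, 1))    # [1000, 1001, 1002]
    print(normalize_room([7, 8], 4, 2))  # [7, 8, 7, 8]
    print(normalize_room(None, 2, 1))    # [None, None]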
@@ -453,7 +509,13 @@ class GenerateReqInput:
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
             log_metrics=self.log_metrics,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             modalities=self.modalities[i] if self.modalities else None,
+            session_params=self.session_params,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
             lora_id=self.lora_id[i] if self.lora_id is not None else None,
             custom_logit_processor=(
@@ -461,11 +523,6 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=(
-                self.return_hidden_states[i]
-                if isinstance(self.return_hidden_states, list)
-                else self.return_hidden_states
-            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -476,9 +533,18 @@ class GenerateReqInput:
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
+            bootstrap_pair_key=(
+                self.bootstrap_pair_key[i]
+                if self.bootstrap_pair_key is not None
+                else None
+            ),
             data_parallel_rank=(
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
+            conversation_id=self.conversation_id,
+            label=self.label,
+            priority=self.priority,
+            return_bytes=self.return_bytes,
         )
 
 
@@ -504,27 +570,28 @@ class TokenizedGenerateReqInput:
     token_ids_logprob: List[int]
     # Whether to stream output
     stream: bool
+    # Whether to return hidden states
+    return_hidden_states: bool = False
 
-    # LoRA related
-    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # LoRA related
+    lora_id: Optional[str] = None  # None means just use the base model
+
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[str] = None
 
-    # Whether to return hidden states
-    return_hidden_states: bool = False
-
     # For disaggregated inference
     bootstrap_host: Optional[str] = None
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
+    bootstrap_pair_key: Optional[str] = None
 
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
@@ -532,6 +599,18 @@ class TokenizedGenerateReqInput:
     # For dp balance
     dp_balance_id: int = -1
 
+    # Label for the request
+    label: Optional[str] = None
+
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # Image gen grpc migration
+    return_bytes: bool = False
+
+    # tracing context
+    trace_context: Optional[Dict] = None
+
 
 @dataclass
 class BatchTokenizedGenerateReqInput:
@@ -581,6 +660,9 @@ class EmbeddingReqInput:
     # For background responses (OpenAI responses API)
     background: bool = False
 
+    # tracing context
+    trace_context: Optional[Dict] = None
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -738,9 +820,26 @@ class BatchTokenIDOut:
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
 
 @dataclass
 class BatchMultimodalDecodeReq:
+    decoded_ids: List[int]
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    read_offsets: List[int]
+    skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
+    image_resolutions: List[List[int]]
+    resize_image_resolutions: List[List[int]]
+
     # The request id
     rids: List[str]
     finished_reasons: List[BaseFinishReason]
@@ -750,6 +849,12 @@ class BatchMultimodalDecodeReq:
     completion_tokens: List[int]
     cached_tokens: List[int]
 
+    # Placeholder token info
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
+    return_bytes: bool = False
+
 
 @dataclass
 class BatchStrOut:
@@ -785,6 +890,9 @@ class BatchStrOut:
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
 
 @dataclass
 class BatchMultimodalOut:
@@ -792,14 +900,26 @@ class BatchMultimodalOut:
     rids: List[str]
     # The finish reason
     finished_reasons: List[dict]
+    decoded_ids: List[List[int]]
     # The outputs
-    outputs: List[List[Dict]]
+    outputs: Union[List[str | bytes], List[List[Dict]]]
+
+    # probability values for input tokens and output tokens
+    input_token_logprobs_val: List[List[float]]
+    input_token_logprobs_idx: List[List[int]]
+    output_token_logprobs_val: List[List[float]]
+    output_token_logprobs_idx: List[List[int]]
 
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
 
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
+
+    return_bytes: List[bool]
+
 
 @dataclass
 class BatchEmbeddingOut:
@@ -812,6 +932,9 @@ class BatchEmbeddingOut:
     # Token counts
     prompt_tokens: List[int]
     cached_tokens: List[int]
+    # Placeholder token info
+    placeholder_tokens_idx: List[Optional[List[int]]]
+    placeholder_tokens_val: List[Optional[List[int]]]
 
 
 @dataclass
@@ -844,6 +967,12 @@ class UpdateWeightFromDiskReqInput:
     abort_all_requests: bool = False
     # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
+    # Whether to update weights asynchronously
+    is_async: bool = False
+    # Whether to empty torch cache
+    torch_empty_cache: bool = False
+    # Whether to keep the scheduler paused after weight update
+    keep_pause: bool = False
 
 
 @dataclass
@@ -900,6 +1029,44 @@ class UpdateWeightsFromTensorReqOutput:
     message: str
 
 
+@dataclass
+class InitWeightsSendGroupForRemoteInstanceReqInput:
+    # The master address
+    master_address: str
+    # The ports for each rank's communication group
+    ports: str
+    # The rank in the communication group
+    group_rank: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_send_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsSendGroupForRemoteInstanceReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class SendWeightsToRemoteInstanceReqInput:
+    # The master address
+    master_address: str
+    # The ports for each rank's communication group
+    ports: str
+    # The group name
+    group_name: str = "weight_send_group"
+
+
+@dataclass
+class SendWeightsToRemoteInstanceReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
@@ -983,6 +1150,7 @@ class AbortReq:
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
+    abort_reason: Optional[str] = None
     # used in MultiTokenizerManager mode
     rids: Optional[Union[List[str], str]] = None
 
@@ -1061,6 +1229,7 @@ class ConfigureLoggingReq:
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None
+    crash_dump_folder: Optional[str] = None
 
 
 @dataclass
@@ -1195,7 +1364,7 @@ class MultiTokenizerRegisterReq:
 
 
 @dataclass
-class
+class MultiTokenizerWrapper:
     worker_id: int
     obj: Optional[Any] = None
 
@@ -1208,3 +1377,21 @@ class BlockReqType(Enum):
 @dataclass
 class BlockReqInput:
     type: BlockReqType
+
+
+@dataclass
+class GetLoadReqInput:
+    pass
+
+
+@dataclass
+class GetLoadReqOutput:
+    dp_rank: int
+    num_reqs: int
+    num_waiting_reqs: int
+    num_tokens: int
+
+
+@dataclass
+class WatchLoadUpdateReq:
+    loads: List[GetLoadReqOutput]
sglang/srt/managers/mm_utils.py
CHANGED
@@ -629,6 +629,7 @@ def general_mm_embed_routine(
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
+        and not forward_batch.forward_mode.is_target_verify()
         and forward_batch.contains_mm_inputs()
     ):
         mm_inputs_list = [
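The extra clause restricts the multimodal embedding path to genuine extend/prefill batches: both decode and target-verify (speculative-decoding verification) passes now skip it. A condensed sketch of the guard, using a simplified stand-in for sglang's `ForwardMode`:

    from enum import Enum, auto


    class ForwardMode(Enum):
        # Simplified stand-in for sglang's ForwardMode.
        EXTEND = auto()
        DECODE = auto()
        TARGET_VERIFY = auto()

        def is_decode(self) -> bool:
            return self is ForwardMode.DECODE

        def is_target_verify(self) -> bool:
            return self is ForwardMode.TARGET_VERIFY


    def should_run_mm_embedding(mode: ForwardMode, contains_mm_inputs: bool) -> bool:
        return not mode.is_decode() and not mode.is_target_verify() and contains_mm_inputs


    print([should_run_mm_embedding(m, True) for m in ForwardMode])
    # [True, False, False] - only EXTEND batches embed multimodal inputs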