sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,579 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
|
15
|
+
import asyncio
|
16
|
+
import logging
|
17
|
+
import multiprocessing as multiprocessing
|
18
|
+
import os
|
19
|
+
import pickle
|
20
|
+
import sys
|
21
|
+
import threading
|
22
|
+
from functools import partialmethod
|
23
|
+
from multiprocessing import shared_memory
|
24
|
+
from typing import Any, Dict
|
25
|
+
|
26
|
+
import setproctitle
|
27
|
+
import zmq
|
28
|
+
import zmq.asyncio
|
29
|
+
|
30
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
31
|
+
from sglang.srt.managers.disagg_service import start_disagg_service
|
32
|
+
from sglang.srt.managers.io_struct import (
|
33
|
+
BatchEmbeddingOut,
|
34
|
+
BatchMultimodalOut,
|
35
|
+
BatchStrOut,
|
36
|
+
BatchTokenIDOut,
|
37
|
+
MultiTokenizerRegisterReq,
|
38
|
+
MultiTokenizerWrapper,
|
39
|
+
)
|
40
|
+
from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
|
41
|
+
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
42
|
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
43
|
+
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
44
|
+
from sglang.utils import get_exception_traceback
|
45
|
+
|
46
|
+
logger = logging.getLogger(__name__)
|
47
|
+
|
48
|
+
|
49
|
+
class SocketMapping:
    """Registry of ZMQ PUSH sockets keyed by worker id.

    Every socket is created from the single ``zmq.Context`` owned by this
    object and is stored in ``_mapping`` under the worker id it was
    registered with.
    """

    def __init__(self):
        self._zmq_context = zmq.Context()
        self._mapping: Dict[str, zmq.Socket] = {}

    def clear_all_sockets(self):
        """Close every registered socket and drop all registrations."""
        for sock in self._mapping.values():
            sock.close()
        self._mapping.clear()

    def register_ipc_mapping(
        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
    ):
        """Create and store a PUSH socket for ``worker_id``.

        The socket targets ``recv_obj.ipc_name``; once stored, ``recv_obj``
        itself is sent through it. A worker id that is already registered is
        left untouched and only a warning is logged.

        Args:
            recv_obj: Registration request carrying the ``ipc_name`` to use.
            worker_id: Key under which the socket is stored.
            is_tokenizer: Only selects the role name used in log messages.
        """
        role = "tokenizer" if is_tokenizer else "detokenizer"
        if worker_id not in self._mapping:
            logger.info(
                f"{role} not registered with worker {worker_id}, registering..."
            )
            new_socket = get_zmq_socket(
                self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False
            )
            self._mapping[worker_id] = new_socket
            new_socket.send_pyobj(recv_obj)
        else:
            logger.warning(
                f"{role} already registered with worker {worker_id}, skipping..."
            )

    def send_output(self, worker_id: str, output: Any):
        """Send ``output`` to the socket registered for ``worker_id``.

        If no socket is registered for that id, an error is logged and the
        output is dropped.
        """
        sock = self._mapping.get(worker_id)
        if sock is None:
            logger.error(
                f"worker ID {worker_id} not registered. Check if the server Process is alive"
            )
            return
        sock.send_pyobj(output)
|
82
|
+
|
83
|
+
|
84
|
+
# Logprob / hidden-state list fields shared by BatchTokenIDOut and BatchStrOut.
_LOGPROB_AND_HIDDEN_STATE_FIELDS = (
    "input_token_logprobs_val",
    "input_token_logprobs_idx",
    "output_token_logprobs_val",
    "output_token_logprobs_idx",
    "input_top_logprobs_val",
    "input_top_logprobs_idx",
    "output_top_logprobs_val",
    "output_top_logprobs_idx",
    "input_token_ids_logprobs_val",
    "input_token_ids_logprobs_idx",
    "output_token_ids_logprobs_val",
    "output_token_ids_logprobs_idx",
    "output_hidden_states",
)

# Per-request list fields that are sliced for each batch output type. Each
# name is both the attribute read from the batch object and the keyword
# argument passed to the constructor of the single-request object.
_TOKEN_ID_OUT_FIELDS = (
    "finished_reasons",
    "decoded_texts",
    "decode_ids",
    "read_offsets",
    "output_ids",
    "skip_special_tokens",
    "spaces_between_special_tokens",
    "no_stop_trim",
    "prompt_tokens",
    "completion_tokens",
    "cached_tokens",
    "spec_verify_ct",
) + _LOGPROB_AND_HIDDEN_STATE_FIELDS

_EMBEDDING_OUT_FIELDS = (
    "finished_reasons",
    "embeddings",
    "prompt_tokens",
    "cached_tokens",
)

_STR_OUT_FIELDS = (
    "finished_reasons",
    "output_strs",
    "output_ids",
    "prompt_tokens",
    "completion_tokens",
    "cached_tokens",
    "spec_verify_ct",
) + _LOGPROB_AND_HIDDEN_STATE_FIELDS

_MULTIMODAL_OUT_FIELDS = (
    "finished_reasons",
    "outputs",
    "prompt_tokens",
    "completion_tokens",
    "cached_tokens",
)


def _select_index(values, i):
    """Return ``[values[i]]``, or ``None`` when ``values`` is ``None`` or too short."""
    if values is None or len(values) <= i:
        return None
    return [values[i]]


def _handle_output_by_index(output, i):
    """Build a single-request copy of a batch output object.

    For ``BatchTokenIDOut``, ``BatchEmbeddingOut``, ``BatchStrOut`` and
    ``BatchMultimodalOut`` a new object of the same class is created whose
    per-request list fields contain only element ``i``. A field that is
    ``None``, empty, or shorter than ``i + 1`` becomes ``None`` in the result.
    ``placeholder_tokens_idx`` and ``placeholder_tokens_val`` are always set to
    ``None``. Any other object is returned unchanged.

    All fields use one None-safe, bounds-checked rule, so a missing or short
    field yields ``None`` instead of raising ``TypeError`` or ``IndexError``.

    Args:
        output: A batch output object, or any other object to pass through.
        i: Position of the request inside the batch.

    Returns:
        The single-request object, or ``output`` itself for unknown types.

    Raises:
        IndexError: If ``output.rids`` has no element ``i`` (``rids`` is
            indexed unconditionally).
    """
    if isinstance(output, BatchTokenIDOut):
        out_cls, field_names = BatchTokenIDOut, _TOKEN_ID_OUT_FIELDS
    elif isinstance(output, BatchEmbeddingOut):
        out_cls, field_names = BatchEmbeddingOut, _EMBEDDING_OUT_FIELDS
    elif isinstance(output, BatchStrOut):
        out_cls, field_names = BatchStrOut, _STR_OUT_FIELDS
    elif isinstance(output, BatchMultimodalOut):
        out_cls, field_names = BatchMultimodalOut, _MULTIMODAL_OUT_FIELDS
    else:
        return output

    sliced = {
        name: _select_index(getattr(output, name), i) for name in field_names
    }
    return out_cls(
        rids=[output.rids[i]],
        placeholder_tokens_idx=None,
        placeholder_tokens_val=None,
        **sliced,
    )
|
343
|
+
|
344
|
+
|
345
|
+
class MultiHttpWorkerDetokenizerMixin:
    """Mixin that fans batched outputs out to per-worker sockets.

    The host class is expected to provide ``recv_from_scheduler`` (an object
    with ``recv_pyobj``) and ``_request_dispatcher`` (a callable); both are
    used by ``multi_http_worker_event_loop``.
    """

    def get_worker_ids_from_req_rids(self, rids):
        """Extract integer worker ids from request ids.

        The worker id is the text before the first ``"_"`` of each request id.

        Args:
            rids: A list of request-id strings, or a single request-id string.

        Returns:
            A list of ints, one per request id; an empty list when ``rids`` is
            neither a list nor a string.
        """
        if isinstance(rids, list):
            return [int(rid.split("_", 1)[0]) for rid in rids]
        if isinstance(rids, str):
            return [int(rids.split("_", 1)[0])]
        return []

    def multi_http_worker_event_loop(self):
        """Run the blocking event loop for multi-http-worker mode.

        Each received object is passed through ``_request_dispatcher``; a
        ``None`` result is skipped. Otherwise the worker ids are parsed from
        ``recv_obj.rids`` and, per worker, either the registration request is
        recorded or the matching slice of the output is sent to that worker.

        Raises:
            RuntimeError: If ``recv_obj.rids`` is not a list.
        """
        self.socket_mapping = SocketMapping()
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            if output is None:
                continue

            # Extract worker_id from rid
            if not isinstance(recv_obj.rids, list):
                raise RuntimeError(
                    "for tokenizer_worker_num > 1, recv_obj.rids must be a list"
                )
            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

            # Send data using the corresponding socket
            is_register_req = isinstance(recv_obj, MultiTokenizerRegisterReq)
            for position, worker_id in enumerate(worker_ids):
                if is_register_req:
                    self.socket_mapping.register_ipc_mapping(
                        recv_obj, worker_id, is_tokenizer=False
                    )
                    continue
                single_output = _handle_output_by_index(output, position)
                self.socket_mapping.send_output(worker_id, single_output)
|
382
|
+
|
383
|
+
|
384
|
+
class MultiTokenizerRouter:
    """A router to receive requests from MultiTokenizerManager.

    On construction it opens three ZMQ sockets, starts a private asyncio event
    loop on a daemon thread, and schedules two coroutines on it:
    ``router_worker_obj`` (forwards worker requests to the scheduler) and
    ``handle_loop`` (distributes detokenizer results back to workers).
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args
        # The positional argument is pyzmq's io_threads count.
        context = zmq.asyncio.Context(3)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )
        self.receive_from_worker = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
        )
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._task = asyncio.run_coroutine_threadsafe(
            self.router_worker_obj(), self._loop
        )
        # Start handle_loop simultaneously
        self._handle_task = asyncio.run_coroutine_threadsafe(
            print_exception_wrapper(self.handle_loop), self._loop
        )
        self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)

    def _run_loop(self):
        """Thread target: run the private event loop until it is stopped."""
        self._loop.run_forever()

    def get_worker_ids_from_req_rids(self, rids):
        """Extract integer worker ids from request ids.

        The worker id is the text before the first ``"_"`` of each request id.
        This class has no base class, so the method must be defined here for
        ``_distribute_result_to_workers`` to call it.

        Args:
            rids: A list of request-id strings, or a single request-id string.

        Returns:
            A list of ints; empty when ``rids`` is neither a list nor a string.
        """
        if isinstance(rids, list):
            return [int(rid.split("_")[0]) for rid in rids]
        if isinstance(rids, str):
            return [int(rids.split("_")[0])]
        return []

    async def router_worker_obj(self):
        """Forward every object received from workers to the scheduler."""
        while True:
            recv_obj = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(recv_obj)

    async def handle_loop(self):
        """Receive results from the detokenizer and route them to workers."""
        # special reqs will recv from scheduler, need to route to right worker
        self.socket_mapping = SocketMapping()
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            await self._distribute_result_to_workers(recv_obj)

    async def _distribute_result_to_workers(self, recv_obj):
        """Distribute result to corresponding workers based on rid.

        A ``MultiTokenizerWrapper`` carries its target ``worker_id`` and is
        unwrapped first; for any other object the worker ids are parsed from
        ``recv_obj.rids``. If no worker id can be determined, an error is
        logged and the result is dropped.
        """
        if isinstance(recv_obj, MultiTokenizerWrapper):
            worker_ids = [recv_obj.worker_id]
            recv_obj = recv_obj.obj
        else:
            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

        if len(worker_ids) == 0:
            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
            return

        # Distribute result to each worker
        for i, worker_id in enumerate(worker_ids):
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                self.socket_mapping.register_ipc_mapping(
                    recv_obj, worker_id, is_tokenizer=True
                )
            else:
                new_recv_obj = _handle_output_by_index(recv_obj, i)
                self.socket_mapping.send_output(worker_id, new_recv_obj)
|
451
|
+
|
452
|
+
|
453
|
+
class MultiTokenizerManager(TokenizerManager):
    """Tokenizer manager used by each worker process in multi-tokenizer mode.

    The worker identifies itself by its PID and registers with the main
    process through ``register_to_main_tokenizer_manager``.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")

        # Run the base initializer with disaggregation set to "null"; per the
        # original author's note this prevents initializing the prefill
        # bootstrap server again. The real mode is restored below.
        saved_mode = server_args.disaggregation_mode
        server_args.disaggregation_mode = "null"
        super().__init__(server_args, port_args)

        self.worker_id = os.getpid()
        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name

        # For PD disaggregation
        self.server_args.disaggregation_mode = saved_mode
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        # Communicator used for the registration round trip; its receive
        # handler is appended to the result dispatcher's mapping.
        self.register_multi_tokenizer_communicator = _Communicator(
            self.send_to_scheduler, 2
        )
        register_handler = (
            MultiTokenizerRegisterReq,
            self.register_multi_tokenizer_communicator.handle_recv,
        )
        self._result_dispatcher._mapping.append(register_handler)

    async def register_to_main_tokenizer_manager(self):
        """Register this worker to the main TokenizerManager.

        Sends a ``MultiTokenizerRegisterReq`` whose single rid is
        ``"<worker_id>_register"`` and whose ``ipc_name`` is this worker's
        tokenizer IPC name, then awaits the communicator.
        """
        # create a handle loop to receive messages from the main TokenizerManager
        self.auto_create_handle_loop()

        register_req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
        register_req.ipc_name = self.tokenizer_ipc_name

        _Communicator.enable_multi_tokenizer = True
        await self.register_multi_tokenizer_communicator(register_req)
|
497
|
+
|
498
|
+
|
499
|
+
async def print_exception_wrapper(func):
    """Await ``func()`` and terminate the process tree if it raises.

    Sometimes an asyncio function does not print its exception, so this
    wrapper logs the traceback explicitly. When ``func`` is a bound method of
    a ``MultiTokenizerRouter`` that offers ``dump_requests_before_crash``,
    that hook is invoked on a best-effort basis before the process is killed.

    The hook is looked up defensively and its own failures are only logged, so
    a missing or failing hook can never prevent ``kill_process_tree`` from
    running.

    Args:
        func: A zero-argument callable returning an awaitable.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"MultiTokenizerRouter hit an exception: {tb_text}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, MultiTokenizerRouter):
            dump_hook = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump_hook):
                try:
                    dump_hook()
                except Exception:
                    logger.exception("dump_requests_before_crash failed")

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|
515
|
+
|
516
|
+
|
517
|
+
def get_main_process_id() -> int:
    """Return the parent PID recorded on the current multiprocessing process.

    The value is read from the private ``_parent_pid`` attribute of
    ``multiprocessing.current_process()``.
    """
    current = multiprocessing.current_process()
    return current._parent_pid
|
520
|
+
|
521
|
+
|
522
|
+
def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
    """Pickle ``obj`` into the shared-memory segment called ``name``.

    An existing segment is reused when it is large enough; a segment that is
    too small is closed, unlinked and recreated; a missing segment is created.
    The pickled bytes are written at the start of the buffer.

    Args:
        obj: Any picklable object.
        name: Name of the shared-memory segment.

    Returns:
        The open ``SharedMemory`` handle; the caller is responsible for
        closing (and eventually unlinking) it.
    """
    payload = pickle.dumps(obj)
    nbytes = len(payload)

    def _create_segment():
        return shared_memory.SharedMemory(create=True, size=nbytes, name=name)

    try:
        # Attach to the segment if it already exists ...
        segment = shared_memory.SharedMemory(name=name)
        # ... and replace it when it cannot hold the payload.
        if segment.size < nbytes:
            segment.close()
            segment.unlink()
            segment = _create_segment()
    except FileNotFoundError:
        # Nothing to attach to: create a fresh segment.
        segment = _create_segment()

    segment.buf[:nbytes] = payload
    return segment
|
540
|
+
|
541
|
+
|
542
|
+
def read_from_shared_memory(name: str) -> Any:
    """Unpickle and return the object stored in shared-memory segment ``name``.

    The segment handle is always closed before returning or raising, including
    when unpickling fails.

    Args:
        name: Name of the shared-memory segment to read.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: If no segment called ``name`` exists.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Shared memory {name} not found") from err

    try:
        # bytes() copies the buffer, so the handle can be closed right after.
        # NOTE(review): pickle.loads executes whatever the payload encodes;
        # this is only safe while the segment is written by trusted processes.
        return pickle.loads(bytes(shm.buf))
    finally:
        shm.close()
|
551
|
+
|
552
|
+
|
553
|
+
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Publish the startup arguments in shared memory for multi-tokenizer mode.

    The tuple ``(port_args, server_args, scheduler_info)`` is written to a
    segment named ``multi_tokenizer_args_<pid>``, where ``<pid>`` is the PID of
    the calling process. The parent PID is looked up for logging only.

    Returns:
        The ``SharedMemory`` handle, already closed in this process (the
        segment itself is not unlinked here).
    """
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    payload = (port_args, server_args, scheduler_info)
    segment_name = f"multi_tokenizer_args_{current_pid}"
    args_shm = write_to_shared_memory(payload, segment_name)
    args_shm.close()

    return args_shm
|
566
|
+
|
567
|
+
|
568
|
+
def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
    """Override the timeout used by uvicorn's multiprocess ``Process.is_alive``.

    ``Process.is_alive`` is replaced with a ``partialmethod`` that pins the
    ``timeout`` keyword to the given value. If the uvicorn multiprocess
    supervisor cannot be imported, a warning is logged and nothing is patched.

    Args:
        timeout: Value bound to the ``timeout`` keyword of ``is_alive``.
    """
    # Per the original note, this raises the timeout from a default of 5s to 10s.
    try:
        from uvicorn.supervisors.multiprocess import Process
    except ImportError:
        logger.warning(
            "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
        )
    else:
        Process.is_alive = partialmethod(Process.is_alive, timeout=timeout)
|