sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
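The list above also includes the one-line change to sglang/version.py, i.e. the version bump from 0.5.2rc2 to 0.5.3rc0. A minimal post-upgrade check, assuming the package keeps exposing `__version__` at the top level as in previous releases:

```python
import sglang

# Expected to print "0.5.3rc0" once the new wheel is installed.
print(sglang.__version__)
```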
sglang/srt/managers/tokenizer_manager.py

@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import torch

@@ -53,18 +42,15 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,

@@ -73,60 +59,38 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,

@@ -180,7 +144,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(

@@ -262,6 +226,18 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
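The hunk above only wires the new AsyncDynamicbatchTokenizer into TokenizerManager; the class itself lives in the newly added sglang/srt/managers/async_dynamic_batch_tokenizer.py. As a rough mental model of what dynamic batch tokenization does (an illustrative sketch with made-up names such as MicroBatchEncoder, not the sglang implementation), concurrent single-text encode calls are queued for a short window and tokenized as one batch:

```python
import asyncio
from typing import Callable, List, Tuple


class MicroBatchEncoder:
    """Illustrative only: gather concurrent encode() calls and tokenize them as one batch."""

    def __init__(
        self,
        encode_batch: Callable[[List[str]], List[List[int]]],
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ):
        self._encode_batch = encode_batch
        self._max_batch_size = max_batch_size
        self._timeout = batch_wait_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker = None

    async def encode(self, text: str) -> List[int]:
        # Lazily start the background batching worker on first use.
        if self._worker is None:
            self._worker = asyncio.create_task(self._run())
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((text, fut))
        return await fut

    async def _run(self):
        while True:
            # Block for the first request, then wait briefly for more to form a batch.
            batch: List[Tuple[str, asyncio.Future]] = [await self._queue.get()]
            deadline = asyncio.get_running_loop().time() + self._timeout
            while len(batch) < self._max_batch_size:
                remaining = deadline - asyncio.get_running_loop().time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            results = self._encode_batch([text for text, _ in batch])
            for (_, fut), ids in zip(batch, results):
                fut.set_result(ids)


async def main():
    # Toy "tokenizer": each text maps to [its length].
    enc = MicroBatchEncoder(lambda texts: [[len(t)] for t in texts])
    print(await asyncio.gather(enc.encode("hello"), enc.encode("sglang")))  # [[5], [6]]


asyncio.run(main())
```

The max_batch_size and batch_wait_timeout_s knobs in this sketch play the same role as the dynamic_batch_tokenizer_batch_size and dynamic_batch_tokenizer_batch_timeout server arguments passed in the hunk above.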
@@ -319,8 +295,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-
-
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0

@@ -328,12 +306,16 @@ class TokenizerManager:

         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
                 server_args=server_args,
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,

@@ -344,50 +326,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)

-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (

@@ -405,100 +343,15 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

-
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)

     async def generate_request(
         self,

@@ -518,6 +371,24 @@ class TokenizerManager:
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"

+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
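The trace_req_start / trace_slice_start calls added above take their `ts` argument as integer nanoseconds; `int(created_time * 1e9)` is simply the seconds-to-nanoseconds conversion:

```python
import time

created_time = time.time()          # float seconds since the epoch, as used above
ts_ns = int(created_time * 1e9)     # integer nanoseconds, the form passed as ts=
assert abs(ts_ns - time.time_ns()) < int(1e9)  # same clock, same unit
```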
@@ -543,6 +414,144 @@ class TokenizerManager:
         ):
             yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+            Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
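For quick reference, the input-format dispatch introduced in this hunk can be exercised standalone; the function below merely restates the branch logic of `_detect_input_format` shown above, together with a few example inputs:

```python
from typing import List, Union


def detect_input_format(texts: Union[str, List], is_cross_encoder: bool) -> str:
    # Mirrors the rule above: str -> single string, list of 2-element lists -> cross-encoder
    # pairs, anything else -> plain batch of strings.
    if isinstance(texts, str):
        return "single_string"
    if is_cross_encoder and len(texts) > 0 and isinstance(texts[0], list) and len(texts[0]) == 2:
        return "cross_encoder_pairs"
    return "batch_strings"


assert detect_input_format("How are you?", is_cross_encoder=False) == "single_string"
assert detect_input_format(["Hello", "World"], is_cross_encoder=False) == "batch_strings"
assert detect_input_format([["query text", "document text"]], is_cross_encoder=True) == "cross_encoder_pairs"
```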
@@ -573,14 +582,10 @@ class TokenizerManager:
                 "accept text prompts. Please provide input_ids or re-initialize "
                 "the engine with skip_tokenizer_init=False."
             )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )

-            input_ids =
-
-
-            token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):

@@ -600,6 +605,7 @@ class TokenizerManager:
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )

@@ -674,7 +680,7 @@ class TokenizerManager:
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )

     def _validate_input_ids_in_vocab(

@@ -755,19 +761,30 @@ class TokenizerManager:
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]

-        #
-
-
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )

         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs

@@ -795,9 +812,12 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state

     def _send_batch_request(

@@ -1015,74 +1035,14 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass

-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()

-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True

@@ -1118,7 +1078,7 @@ class TokenizerManager:
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:

@@ -1143,191 +1103,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):

@@ -1339,7 +1114,7 @@ class TokenizerManager:
             return None

         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()

@@ -1352,28 +1127,6 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None

@@ -1492,6 +1245,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:

@@ -1711,6 +1467,9 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
             state.finished_time = time.time()
             meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
             del self.rid_to_state[rid]

             # Mark ongoing LoRA request as finished.

@@ -1860,6 +1619,12 @@ class TokenizerManager:
             else 0
         )

+        customer_labels = getattr(state.obj, "customer_labels", None)
+        labels = (
+            {**self.metrics_collector.labels, **customer_labels}
+            if customer_labels
+            else self.metrics_collector.labels
+        )
         if (
             state.first_token_time == 0.0
             and self.disaggregation_mode != DisaggregationMode.PREFILL

@@ -1867,7 +1632,7 @@ class TokenizerManager:
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
-                state.first_token_time - state.created_time
+                labels, state.first_token_time - state.created_time
             )
         else:
             num_new_tokens = completion_tokens - state.last_completion_tokens

@@ -1875,6 +1640,7 @@ class TokenizerManager:
                 new_time = time.time()
                 interval = new_time - state.last_time
                 self.metrics_collector.observe_inter_token_latency(
+                    labels,
                     interval,
                     num_new_tokens,
                 )

@@ -1889,6 +1655,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
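The three metrics hunks above thread a per-request `labels` dict into the collector calls. The merge itself is plain dict unpacking: customer labels supplied with the request (when present) extend the collector's base labels. A tiny standalone illustration, with made-up label values:

```python
base_labels = {"model_name": "my-served-model"}   # collector-level labels (illustrative value)
customer_labels = {"tenant": "team-a"}            # per-request labels; may also be None

labels = {**base_labels, **customer_labels} if customer_labels else base_labels
print(labels)  # {'model_name': 'my-served-model', 'tenant': 'team-a'}
```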
@@ -2060,11 +1827,15 @@ class TokenizerManager:
         # the next position after the last token in the prompt
         output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

-        #
-        if
+        # Check if output_logprobs is properly populated
+        if (
+            output_logprobs is None
+            or not output_logprobs
+            or len(output_logprobs) == 0
+        ):
             raise RuntimeError(
-                f"output_logprobs is
-                "This
+                f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for the scoring request."
             )

         for logprob, token_id, _ in output_logprobs[0]:

@@ -2089,6 +1860,20 @@ class TokenizerManager:

         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_udpate_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_udpate_req)
+

 class ServerStatus(Enum):
     Up = "Up"

@@ -2140,51 +1925,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())


-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
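The removed `_Communicator` helper implemented a send-then-gather pattern: one request object is sent to the scheduler, and the caller blocks until `fan_out` responses (one per DP rank) have been collected by `handle_recv`. Per the earlier hunks, this role is now taken over by `TokenizerCommunicatorMixin` and its `init_communicators` call. A reduced, illustrative sketch of the pattern (not the sglang class itself):

```python
import asyncio
from typing import Any, Callable, List


class FanOutGather:
    """Illustrative only: send one request, then wait for `fan_out` responses."""

    def __init__(self, fan_out: int):
        self._fan_out = fan_out
        self._event = asyncio.Event()
        self._results: List[Any] = []

    async def call(self, send: Callable[[], None]) -> List[Any]:
        send()                     # e.g. publish the request to all DP ranks
        await self._event.wait()   # block until every rank has answered
        return self._results

    def handle_recv(self, obj: Any) -> None:
        # Invoked by the result dispatcher for each incoming response.
        self._results.append(obj)
        if len(self._results) == self._fan_out:
            self._event.set()


async def main():
    comm = FanOutGather(fan_out=2)

    async def rank(i: int) -> None:   # simulate two DP ranks replying
        await asyncio.sleep(0.01 * i)
        comm.handle_recv(f"reply-{i}")

    results, *_ = await asyncio.gather(comm.call(lambda: None), rank(1), rank(2))
    print(results)                    # ['reply-1', 'reply-2']


asyncio.run(main())
```

Unlike the removed class, this sketch supports only a single call; the original additionally queued concurrent callers so that at most one request was in flight at a time.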