sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import torch
@@ -53,72 +42,42 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -127,6 +86,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -174,7 +134,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""
 
     def __init__(
@@ -262,9 +222,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )
 
         # Request states
         self.no_create_loop = False
@@ -307,36 +273,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For PD disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.bootstrap_server = start_disagg_service(self.server_args)
 
         # For load balancing
         self.current_load = 0
@@ -345,6 +285,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
@@ -359,47 +300,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
 
-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -417,66 +317,16 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
+        self.init_communicators(server_args)
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -486,6 +336,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()
 
+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -768,6 +627,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state
 
+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +753,17 @@ class TokenizerManager:
 
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
 
-                for i, tokenized_obj in enumerate(tokenized_objs):
+                # Send as a single batched request
+                self._send_batch_request(obj, tokenized_objs, created_time)
+
+                # Set up generators for each request in the batch
+                for i in range(batch_size):
                     tmp_obj = obj[i]
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
+                    generators.append(
+                        self._wait_one_response(
+                            tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                        )
+                    )
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
@@ -952,9 +842,6 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass
 
-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -964,55 +851,6 @@ class TokenizerManager:
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
 
-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True
@@ -1047,6 +885,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1071,191 +911,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests
 
-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -1266,6 +921,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None
 
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
 
         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1278,30 +935,6 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.internal_state for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1543,7 +1176,6 @@ class TokenizerManager:
 
     async def handle_loop(self):
        """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1563,9 +1195,12 @@ class TokenizerManager:
                 )
                 continue
 
+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1871,6 +1506,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1883,7 +1521,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2063,47 +1701,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
```