sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +77 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/detokenizer_manager.py +2 -0
- sglang/srt/managers/io_struct.py +40 -9
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/scheduler.py +69 -21
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +48 -10
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -0
- sglang/srt/models/llama.py +11 -0
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +60 -2
- sglang/srt/openai_api/protocol.py +48 -0
- sglang/srt/server.py +26 -3
- sglang/srt/server_args.py +24 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/METADATA +3 -3
- sglang-0.4.1.post3.dist-info/RECORD +305 -0
- sglang-0.4.1.post1.dist-info/RECORD +0 -195
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/top_level.txt +0 -0
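The bulk of the new files above are pre-tuned Triton kernel configurations, keyed by expert count (E), intermediate size (N), GPU device name, and an optional dtype and block shape. As a rough, hypothetical sketch of how such a file name is assembled (the real lookup lives in fused_moe_triton/fused_moe.py, whose changes are not shown in full here):

from typing import List, Optional

def moe_config_file_name(E: int, N: int, device_name: str,
                         dtype: Optional[str] = None,
                         block_shape: Optional[List[int]] = None) -> str:
    # Mirrors the naming pattern of the JSON files listed above; this helper is
    # illustrative only and is not part of the sglang API.
    name = f"E={E},N={N},device_name={device_name}"
    if dtype is not None:
        name += f",dtype={dtype}"
    if block_shape is not None:
        name += f",block_shape={block_shape}"
    return name + ".json"

print(moe_config_file_name(8, 7168, "NVIDIA_H100_80GB_HBM3", dtype="fp8_w8a8"))
# E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json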
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:

         last_token_id = self.output_ids[-1]

-        [12 removed lines not captured in this rendering]
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return

         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
@@ -836,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True

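The hunks above add a per-request eos_token_ids set to Req (populated from model_config.hf_eos_token_id in the scheduler changes below) and fold it into the stop check. A standalone sketch of the resulting logic, written as a free function rather than the actual Req.check_finished method; argument names are illustrative:

from typing import Optional, Set

def matched_stop_token(last_token_id: int,
                       stop_token_ids: Optional[Set[int]],
                       eos_token_ids: Optional[Set[int]],
                       tokenizer_eos_id: Optional[int],
                       additional_stop_token_ids: Optional[Set[int]],
                       ignore_eos: bool) -> bool:
    # Condensed restatement of the diff above, outside of sglang.
    if ignore_eos:
        return False
    matched = False
    if stop_token_ids:
        matched = last_token_id in stop_token_ids          # per-request stop ids
    if eos_token_ids:
        matched |= last_token_id in eos_token_ids          # EOS ids from the model config
    if tokenizer_eos_id is not None:
        matched |= last_token_id == tokenizer_eos_id       # tokenizer's own EOS
    if additional_stop_token_ids:
        matched |= last_token_id in additional_stop_token_ids
    return matched

assert matched_stop_token(2, None, {2}, 1, None, ignore_eos=False)
assert not matched_stop_token(2, None, {2}, 1, None, ignore_eos=True)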
sglang/srt/managers/scheduler.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import
+from typing import Dict, List, Optional, Tuple

 import psutil
 import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


@@ -127,12 +129,12 @@ class Scheduler:
        )

        if server_args.skip_tokenizer_init:
-            # Directly send to the
+            # Directly send to the TokenizerManager
            self.send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )
        else:
-            # Send to the
+            # Send to the DetokenizerManager
            self.send_to_detokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.detokenizer_ipc_name
            )
@@ -383,7 +385,8 @@ class Scheduler:
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
-
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch
@@ -392,7 +395,7 @@ class Scheduler:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
-                #
+                # When the server is idle, so self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

@@ -409,12 +412,13 @@ class Scheduler:

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
+
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
-                    #
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
@@ -424,19 +428,21 @@ class Scheduler:
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
+                # Process the results of the last batch
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
-                #
+                # When the server is idle, so self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
+        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

@@ -478,6 +484,11 @@ class Scheduler:
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromDistributedReqOutput(success, message)
                )
+            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
+                success, message = self.update_weights_from_tensor(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromTensorReqOutput(success, message)
+                )
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -487,8 +498,10 @@ class Scheduler:
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
-                session_id = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
+                session_id, success = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    OpenSessionReqOutput(session_id=session_id, success=success)
+                )
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            else:
@@ -499,7 +512,11 @@ class Scheduler:
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
-        if
+        if (
+            recv_req.session_params is None
+            or recv_req.session_params.id is None
+            or recv_req.session_params.id not in self.sessions
+        ):

            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
@@ -517,18 +534,22 @@ class Scheduler:
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
+                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

-            if
+            if (
+                recv_req.session_params is not None
+                and recv_req.session_params.id is not None
+            ):
                req.finished_reason = FINISH_ABORT(
-                    f"Invalid request: session id {recv_req.
+                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
-            # Create a new request from a
-            session = self.sessions[recv_req.
+            # Create a new request from a previous session
+            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
@@ -804,6 +825,8 @@ class Scheduler:
            if res == AddReqResult.NO_TOKEN:
                self.batch_is_full = True
                break
+            if self.server_args.prefill_only_one_req:
+                break

        # Update waiting queue
        can_run_list = adder.can_run_list
@@ -1195,6 +1218,7 @@ class Scheduler:
            decode_ids_list = []
            read_offsets = []
            output_ids = []
+            origin_input_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
@@ -1243,8 +1267,14 @@ class Scheduler:
                decode_ids, read_offset = req.init_incremental_detokenize()
                decode_ids_list.append(decode_ids)
                read_offsets.append(read_offset)
-                if self.skip_tokenizer_init:
+                if self.skip_tokenizer_init or self.server_args.return_token_ids:
                    output_ids.append(req.output_ids)
+                else:
+                    output_ids = None
+                if self.server_args.return_token_ids:
+                    origin_input_ids.append(req.origin_input_ids)
+                else:
+                    origin_input_ids = None
                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
@@ -1276,6 +1306,7 @@ class Scheduler:
                    decoded_texts,
                    decode_ids_list,
                    read_offsets,
+                    origin_input_ids,
                    output_ids,
                    skip_special_tokens,
                    spaces_between_special_tokens,
@@ -1457,6 +1488,17 @@ class Scheduler:
            logger.error(message)
        return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        """Update the online model parameter from tensors."""
+        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter
@@ -1475,16 +1517,20 @@ class Scheduler:
        )
        logger.info("Profiler is done")

-    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
+            return session_id, False
+        elif session_id is None:
+            logger.warning(f"session id is None, cannot open.")
+            return session_id, False
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
-
+            return session_id, True

    def close_session(self, recv_req: CloseSessionReqInput):
        # handle error
@@ -1509,18 +1555,20 @@ def run_scheduler_process(
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

+    # Configue the logger
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()

-    #
+    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    suppress_other_loggers()
    parent_process = psutil.Process().parent()

+    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
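Two behavioral changes in this file are easy to miss in the noise: open_session now returns a (session_id, success) tuple instead of a bare string, and the new update_weights_from_tensor path flushes the KV cache after a successful weight swap. A toy sketch of the new open_session contract (not the sglang class itself; the Session stand-in is hypothetical):

from typing import Dict, Optional, Tuple

def open_session(sessions: Dict[str, object],
                 session_id: Optional[str]) -> Tuple[Optional[str], bool]:
    # Duplicate or missing ids are rejected explicitly instead of silently reused.
    if session_id is None or session_id in sessions:
        return session_id, False
    sessions[session_id] = object()  # stand-in for Session(capacity_of_str_len, session_id)
    return session_id, True

sessions: Dict[str, object] = {}
assert open_session(sessions, "abc") == ("abc", True)
assert open_session(sessions, "abc") == ("abc", False)   # already open
assert open_session(sessions, None) == (None, False)     # missing id

The TokenizerManager (see its hunks further down) resolves the pending future to the session id only when success is true, so API callers receive None when a session cannot be opened.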
sglang/srt/managers/session_controller.py
CHANGED
@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================

+import logging
 import uuid
+from typing import Dict, Optional

 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret


 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.
+        self.req_nodes: Dict[str, SessionReqNode] = {}

     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        [5 removed lines not captured in this rendering]
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
         else:
-            [2 removed lines not captured in this rendering]
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
             input_ids = (
-
-                +
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
             input_ids_unpadded = (
-
-                +
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if
-            new_req.image_inputs =
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if
-            new_req.
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
+
         return new_req
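Session requests now form a tree of SessionReqNode objects, and a follow-up request can branch from (or replace) any earlier request in the session. When it appends, the new prompt is the previous request's original input plus its output capped at max_new_tokens, optionally truncated at session_params.offset before the new tokens are attached. A toy sketch of that assembly, using plain lists in place of Req objects:

from typing import List, Optional

def build_session_input_ids(prev_input_ids: List[int], prev_output_ids: List[int],
                            max_new_tokens: int, new_input_ids: List[int],
                            offset: Optional[int] = None) -> List[int]:
    # History = previous prompt + previous output, capped at max_new_tokens.
    history = prev_input_ids + prev_output_ids[:max_new_tokens]
    if offset:  # a non-zero offset truncates the history before appending
        return history[:offset] + new_input_ids
    return history + new_input_ids

# Append a follow-up turn to a finished request's full history.
print(build_session_input_ids([1, 5, 6], [7, 8, 2], 32, [9, 10]))
# [1, 5, 6, 7, 8, 2, 9, 10]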
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -179,6 +182,9 @@ class TokenizerManager:
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
+        self.update_weights_from_tensor_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
@@ -259,8 +265,9 @@ class TokenizerManager:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
-            [2 removed lines not captured in this rendering]
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
@@ -287,8 +294,7 @@ class TokenizerManager:
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
-
-                session_rid=session_rid,
+                session_params=session_params,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
@@ -515,6 +521,22 @@ class TokenizerManager:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_tensor_communicator(obj))[0]
+            return result.success, result.message
+
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
@@ -531,12 +553,16 @@ class TokenizerManager:
    ):
        self.auto_create_handle_loop()

-        session_id
-
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
        self.send_to_scheduler.send_pyobj(obj)
-        [3 removed lines not captured in this rendering]
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
@@ -637,6 +663,13 @@ class TokenizerManager:
                    "text": recv_obj.output_strs[i],
                    "meta_info": meta_info,
                }
+                if self.server_args.return_token_ids:
+                    out_dict.update(
+                        {
+                            "input_ids": recv_obj.origin_input_ids[i],
+                            "output_ids": recv_obj.output_ids[i],
+                        }
+                    )
            elif isinstance(recv_obj, BatchTokenIDOut):
                out_dict = {
                    "token_ids": recv_obj.output_ids[i],
@@ -688,7 +721,7 @@ class TokenizerManager:
                )
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
+                    recv_obj.session_id if recv_obj.success else None
                )
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
@@ -708,6 +741,11 @@ class TokenizerManager:
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                self.get_weights_by_name_communicator.handle_recv(recv_obj)
            else:
sglang/srt/managers/tp_worker.py
CHANGED
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
        )
        return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.model_runner.update_weights_from_tensor(
+            recv_req.name, recv_req.tensor
+        )
+        return success, message
+
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.worker.update_weights_from_tensor(recv_req)
+        return success, message
+
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

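Taken together, the last three files wire UpdateWeightsFromTensorReqInput from the TokenizerManager through the Scheduler and TpModelWorker down to ModelRunner.update_weights_from_tensor(name, tensor); the model_runner.py change itself is not shown above. As a generic illustration only, and not sglang's implementation, the core of such an update is an in-place copy into the named parameter of the live model:

import torch
import torch.nn as nn

def update_weight_from_tensor(model: nn.Module, name: str, tensor: torch.Tensor):
    # Copy a replacement tensor into the parameter registered under `name`.
    params = dict(model.named_parameters())
    if name not in params:
        return False, f"parameter {name} not found"
    with torch.no_grad():
        params[name].copy_(tensor.to(device=params[name].device,
                                     dtype=params[name].dtype))
    return True, f"updated {name}"

model = nn.Linear(4, 4)
ok, msg = update_weight_from_tensor(model, "weight", torch.zeros(4, 4))
print(ok, msg)  # True updated weight

In sglang itself the scheduler additionally flushes the KV cache after a successful update (see the update_weights_from_tensor hunk in scheduler.py), since cached entries were computed with the old weights.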