sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py
CHANGED
```diff
@@ -35,12 +35,12 @@ class SessionReqNode:
         for req_node in self.childs:
             req_node.clear(req_dict)

-        if self.req.finished_reason
+        if self.req.finished_reason is None:
             self.req.to_abort = True
         del req_dict[self.req.rid]

     def abort(self):
-        if self.req.finished_reason
+        if self.req.finished_reason is None:
             self.req.to_abort = True

     def __str__(self):
@@ -132,6 +132,10 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
+            stream=req.stream,
+            return_logprob=req.return_logprob,
+            top_logprobs_num=req.top_logprobs_num,
+            token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
```
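The four forwarded fields matter for session continuation: before this change, a follow-up request in an open session dropped the caller's streaming and logprob settings. A hypothetical follow-up payload (the `session_params` keys here are illustrative, not confirmed by this diff):

```python
# Hypothetical follow-up request in an open session. With the fields above
# forwarded, the child Req now inherits the caller's streaming and logprob
# settings rather than the session defaults.
payload = {
    "text": "And then?",
    "session_params": {"id": session_id, "rid": previous_rid},  # assumed keys
    "stream": True,
    "return_logprob": True,
    "top_logprobs_num": 5,
}
```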
sglang/srt/managers/tokenizer_manager.py
CHANGED
```diff
@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -24,9 +25,21 @@ import sys
 import threading
 import time
 import uuid
+from collections import deque
 from datetime import datetime
 from http import HTTPStatus
-from typing import
+from typing import (
+    Any,
+    Awaitable,
+    Deque,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import fastapi
 import uvloop
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
@@ -51,13 +65,18 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetInternalStateReq,
+    GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
+    HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ProfileReqOutput,
+    ProfileReqType,
     ReleaseMemoryOccupationReqInput,
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
```
```diff
@@ -98,7 +117,10 @@ class ReqState:

     # For metrics
     created_time: float
-
+    finished_time: float = 0.0
+    first_token_time: float = 0.0
+    last_time: float = 0.0
+    last_completion_tokens: int = 1

     # For streaming output
     last_output_offset: int = 0
```
```diff
@@ -113,11 +135,10 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
-
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
-        self.log_requests_level =
+        self.log_requests_level = server_args.log_requests_level

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -143,6 +164,7 @@ class TokenizerManager:
         )

         self.is_generation = self.model_config.is_generation
+        self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id

@@ -178,9 +200,12 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.gracefully_exit = False
+        self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
+        self.log_request_metadata = self.get_log_request_metadata()

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -192,8 +217,19 @@ class TokenizerManager:
         # For session info
         self.session_futures = {}  # session_id -> asyncio event

-        #
-        self.
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
+
+        # Metrics
+        if self.enable_metrics:
+            self.metrics_collector = TokenizerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
+        # Communicators
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
```
```diff
@@ -212,22 +248,23 @@
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-
-        # Metrics
-        if self.enable_metrics:
-            self.metrics_collector = TokenizerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.start_profile_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
+        self.get_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
-                    (
+                    (
+                        BatchStrOut,
+                        BatchEmbeddingOut,
+                        BatchTokenIDOut,
+                        BatchMultimodalOut,
+                    ),
                     self._handle_batch_output,
                 ),
                 (OpenSessionReqOutput, self._handle_open_session_req_output),
@@ -259,6 +296,15 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    ProfileReqOutput,
+                    self.start_profile_communicator.handle_recv,
+                ),
+                (
+                    GetInternalStateReqOutput,
+                    self.get_internal_state_communicator.handle_recv,
+                ),
+                (HealthCheckOutput, lambda x: None),
             ]
         )

```
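For readers unfamiliar with the pattern: `TypeBasedDispatcher` routes each object received from the detokenizer to the handler registered for its type, with a tuple of types sharing one handler. A minimal stand-in, assuming the real class behaves like this sketch:

```python
# Minimal stand-in for the TypeBasedDispatcher pattern used above.
class TypeBasedDispatcher:
    def __init__(self, mapping):
        self._mapping = mapping  # list of (type or tuple of types, handler)

    def __call__(self, obj):
        # Dispatch to the first handler whose registered type(s) match.
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"Unhandled object type: {type(obj)}")
```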
```diff
@@ -280,9 +326,9 @@ class TokenizerManager:
         obj.normalize_batch_and_arguments()

         if self.log_requests:
-            max_length
+            max_length, skip_names, _ = self.log_request_metadata
             logger.info(
-                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

         async with self.model_update_lock.reader_lock:
@@ -336,6 +382,7 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
+        token_ids_logprob = obj.token_ids_logprob
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
@@ -378,11 +425,13 @@ class TokenizerManager:
                 return_logprob,
                 logprob_start_len,
                 top_logprobs_num,
+                token_ids_logprob,
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
+                return_hidden_states=obj.return_hidden_states,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -400,8 +449,7 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-
-        state = ReqState([], False, event, obj, created_time=created_time)
+        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
         self.send_to_scheduler.send_pyobj(tokenized_obj)

```
```diff
@@ -419,7 +467,10 @@ class TokenizerManager:
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(obj.rid)
-                    raise ValueError(
+                    raise ValueError(
+                        "Request is disconnected from the client side. "
+                        f"Abort request {obj.rid}"
+                    )
                 continue

             out = state.out_list[-1]
@@ -427,8 +478,11 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
                 if self.log_requests:
-                    max_length
-
+                    max_length, skip_names, out_skip_names = self.log_request_metadata
+                    if self.model_config.is_multimodal_gen:
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
+                    else:
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]

@@ -451,7 +505,10 @@ class TokenizerManager:
             else:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(obj.rid)
-                    raise ValueError(
+                    raise ValueError(
+                        "Request is disconnected from the client side. "
+                        f"Abort request {obj.rid}"
+                    )

     async def _handle_batch_request(
         self,
```
```diff
@@ -542,12 +599,25 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

-    def start_profile(
-
-
+    async def start_profile(
+        self,
+        output_dir: Optional[str] = None,
+        num_steps: Optional[int] = None,
+        activities: Optional[List[str]] = None,
+    ):
+        req = ProfileReq(
+            type=ProfileReqType.START_PROFILE,
+            output_dir=output_dir,
+            num_steps=num_steps,
+            activities=activities,
+        )
+        result = (await self.start_profile_communicator(req))[0]
+        if not result.success:
+            raise RuntimeError(result.message)
+        return result

     def stop_profile(self):
-        req = ProfileReq.STOP_PROFILE
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)

     async def update_weights_from_disk(
```
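`start_profile` is now an awaitable request/response round trip through the new `start_profile_communicator` rather than a fire-and-forget enum send, so failures surface to the caller. A usage sketch under assumed argument values (the path and activity names are placeholders, not values confirmed by this diff):

```python
async def profile_decode(tokenizer_manager):
    # Blocks until every scheduler rank acknowledges; raises
    # RuntimeError(result.message) if any rank fails to start profiling.
    await tokenizer_manager.start_profile(
        output_dir="/tmp/sglang_profile",  # hypothetical output directory
        num_steps=20,                      # assumed: scheduler stops on its own
        activities=["CPU", "GPU"],         # assumed activity names
    )
```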
```diff
@@ -580,7 +650,7 @@ class TokenizerManager:
                 self.server_args.model_path = obj.model_path
                 self.server_args.load_format = obj.load_format
                 self.model_path = obj.model_path
-            return result.success, result.message
+            return result.success, result.message, result.num_paused_requests
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp = []
             result = await self.model_update_result
@@ -592,7 +662,8 @@ class TokenizerManager:
                 self.model_path = obj.model_path
             all_message = [r.message for r in result]
             all_message = " | ".join(all_message)
-
+            all_paused_requests = [r.num_paused_requests for r in result]
+            return all_success, all_message, all_paused_requests

     async def init_weights_update_group(
         self,
```
```diff
@@ -687,6 +758,46 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

+    async def get_internal_state(self) -> Dict[Any, Any]:
+        req = GetInternalStateReq()
+        res: List[GetInternalStateReqOutput] = (
+            await self.get_internal_state_communicator(req)
+        )
+        return res[0].internal_state
+
+    def get_log_request_metadata(self):
+        max_length = None
+        skip_names = None
+        out_skip_names = None
+        if self.log_requests:
+            if self.log_requests_level == 0:
+                max_length = 1 << 30
+                skip_names = set(
+                    [
+                        "text",
+                        "input_ids",
+                        "input_embeds",
+                        "image_data",
+                        "audio_data",
+                        "lora_path",
+                    ]
+                )
+                out_skip_names = set(
+                    [
+                        "text",
+                        "output_ids",
+                    ]
+                )
+            elif self.log_requests_level == 1:
+                max_length = 2048
+            elif self.log_requests_level == 2:
+                max_length = 1 << 30
+            else:
+                raise ValueError(
+                    f"Invalid --log-requests-level: {self.log_requests_level=}"
+                )
+        return max_length, skip_names, out_skip_names
+
     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
```
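Reading the new helper: `--log-requests-level 0` logs request metadata at effectively unbounded length but skips the bulky payload fields (prompt text, token ids, embeddings, image/audio data, LoRA path) on the input side and the generated text/ids on the output side; level 1 logs everything truncated to 2048 characters; level 2 logs everything untruncated.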
```diff
@@ -697,6 +808,7 @@ class TokenizerManager:
         if obj.dump_requests_threshold is not None:
             self.dump_requests_threshold = obj.dump_requests_threshold
         logging.info(f"Config logging: {obj=}")
+        self.log_request_metadata = self.get_log_request_metadata()

     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
@@ -761,15 +873,20 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
+            self.last_receive_tstamp = time.time()

     def _handle_batch_output(
-        self,
+        self,
+        recv_obj: Union[
+            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+        ],
     ):
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
                 continue

+            # Build meta_info and return value
             meta_info = {
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
```
```diff
@@ -780,14 +897,12 @@ class TokenizerManager:
                 self.convert_logprob_style(
                     meta_info,
                     state.obj.top_logprobs_num,
+                    state.obj.token_ids_logprob,
                     state.obj.return_text_in_logprobs,
                     recv_obj,
                     i,
                 )

-            if self.server_args.speculative_algorithm:
-                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
-
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -796,10 +911,7 @@
                     }
                 )

-            if (
-                hasattr(recv_obj, "output_hidden_states")
-                and len(recv_obj.output_hidden_states[i]) > 0
-            ):
+            if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

             if isinstance(recv_obj, BatchStrOut):
@@ -808,10 +920,20 @@
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
+                if self.server_args.stream_output and state.obj.stream:
+                    output_token_ids = recv_obj.output_ids[i][
+                        state.last_output_offset :
+                    ]
+                    state.last_output_offset = len(recv_obj.output_ids[i])
+                else:
+                    output_token_ids = recv_obj.output_ids[i]
+
                 out_dict = {
-                    "
+                    "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
+            elif isinstance(recv_obj, BatchMultimodalOut):
+                raise NotImplementedError()
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
```
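The `BatchTokenIDOut` branch now returns only the delta of output ids per streamed chunk, tracked by `last_output_offset`. Reduced to plain lists (values hypothetical):

```python
# last_output_offset bookkeeping from the branch above, in isolation.
cumulative_ids = [5, 9, 12, 14]    # all ids produced so far for one request
last_output_offset = 2             # ids already delivered to this client

chunk = cumulative_ids[last_output_offset:]   # -> [12, 14], new tokens only
last_output_offset = len(cumulative_ids)      # next chunk starts after these
```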
```diff
@@ -819,10 +941,17 @@
                     "meta_info": meta_info,
                 }

-            state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
+            if state.finished:
+                if self.server_args.speculative_algorithm:
+                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+                state.finished_time = time.time()
+                meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            state.out_list.append(out_dict)
             state.event.set()

+            # Log metrics and dump
             if self.enable_metrics and state.obj.log_metrics:
                 self.collect_metrics(state, recv_obj, i)
             if self.dump_requests_folder and state.finished and state.obj.log_metrics:
@@ -832,6 +961,7 @@
         self,
         meta_info: dict,
         top_logprobs_num: int,
+        token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
@@ -859,6 +989,20 @@
             return_text_in_logprobs,
         )

+        if token_ids_logprob is not None:
+            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
+                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
+            )
+            meta_info["output_token_ids_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
+                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    return_text_in_logprobs,
+                )
+            )
+
     def detokenize_logprob_tokens(
         self,
         token_logprobs_val: List[float],
```
```diff
@@ -902,34 +1046,30 @@
             else 0
         )

-        if state.first_token_time
-            state.first_token_time = time.time()
+        if state.first_token_time == 0.0:
+            state.first_token_time = state.last_time = time.time()
+            state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
                 state.first_token_time - state.created_time
             )
         else:
-
-
-
-
+            num_new_tokens = completion_tokens - state.last_completion_tokens
+            if num_new_tokens:
+                new_time = time.time()
+                interval = new_time - state.last_time
+                self.metrics_collector.observe_inter_token_latency(
+                    interval,
+                    num_new_tokens,
                 )
+                state.last_time = new_time
+                state.last_completion_tokens = completion_tokens

         if state.finished:
             self.metrics_collector.observe_one_finished_request(
-                recv_obj.prompt_tokens[i],
-
-
-                time.time() - state.created_time
+                recv_obj.prompt_tokens[i],
+                completion_tokens,
+                state.finished_time - state.created_time,
             )
-            # Compute time_per_output_token for the non-streaming case
-            if (
-                hasattr(state.obj, "stream")
-                and not state.obj.stream
-                and completion_tokens >= 1
-            ):
-                self.metrics_collector.observe_time_per_output_token(
-                    (time.time() - state.created_time) / completion_tokens
-                )

     def dump_requests(self, state: ReqState, out_dict: dict):
         self.dump_request_list.append(
```
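The removed non-streaming `time_per_output_token` estimate is replaced by interval-based accounting: each received chunk contributes one interval that spans `num_new_tokens` tokens, which works for streaming and non-streaming alike. A stand-alone sketch of the same bookkeeping (`state` is a stand-in object carrying the new `ReqState` fields):

```python
import time

def observe_chunk(state, completion_tokens):
    """Mirror of the bookkeeping above. Returns ("ttft", seconds) for the
    first chunk, ("itl", seconds, n) when n new tokens arrived, or None
    when the chunk carried no new tokens."""
    if state.first_token_time == 0.0:
        # First chunk: time-to-first-token, measured from request creation.
        state.first_token_time = state.last_time = time.time()
        state.last_completion_tokens = completion_tokens
        return ("ttft", state.first_token_time - state.created_time)
    num_new_tokens = completion_tokens - state.last_completion_tokens
    if not num_new_tokens:
        return None
    now = time.time()
    interval = now - state.last_time       # spans num_new_tokens tokens
    state.last_time = now
    state.last_completion_tokens = completion_tokens
    return ("itl", interval, num_new_tokens)
```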
```diff
@@ -984,7 +1124,7 @@ async def print_exception_wrapper(func):


 class SignalHandler:
-    def __init__(self, tokenizer_manager):
+    def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

     def signal_handler(self, signum=None, frame=None):
@@ -998,22 +1138,38 @@ T = TypeVar("T")


 class _Communicator(Generic[T]):
+    """Note: The communicator now only run up to 1 in-flight request at any time."""
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
-        self.
+        self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
+        self._ready_queue: Deque[asyncio.Future] = deque()

     async def __call__(self, obj):
-
-        self.
+        ready_event = asyncio.Event()
+        if self._result_event is not None or len(self._ready_queue) > 0:
+            self._ready_queue.append(ready_event)
+            await ready_event.wait()
+            assert self._result_event is None
+            assert self._result_values is None
+
+        if obj:
+            self._sender.send_pyobj(obj)
+
+        self._result_event = asyncio.Event()
         self._result_values = []
-        await self.
+        await self._result_event.wait()
         result_values = self._result_values
-        self.
+        self._result_event = self._result_values = None
+
+        if len(self._ready_queue) > 0:
+            self._ready_queue.popleft().set()
+
         return result_values

     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
-            self.
+            self._result_event.set()
```
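The rewritten `_Communicator` serializes concurrent callers with a FIFO of ready events and completes a call once `fan_out` replies arrive (one per data-parallel rank). A runnable toy version of the same pattern, with a no-op sender standing in for the ZMQ socket (all names here are hypothetical):

```python
import asyncio
from collections import deque

class FanOutCommunicator:
    """Toy _Communicator: one in-flight request; a call completes once
    fan_out replies have been handed in via handle_recv."""

    def __init__(self, sender, fan_out):
        self._sender, self._fan_out = sender, fan_out
        self._result_event = None
        self._result_values = None
        self._ready_queue = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or self._ready_queue:
            self._ready_queue.append(ready_event)  # earlier caller in flight
            await ready_event.wait()
        self._sender(obj)
        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()            # until fan_out replies
        values = self._result_values
        self._result_event = self._result_values = None
        if self._ready_queue:
            self._ready_queue.popleft().set()      # wake the next caller
        return values

    def handle_recv(self, recv_obj):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()

async def main():
    comm = FanOutCommunicator(sender=lambda obj: None, fan_out=2)
    task = asyncio.create_task(comm("ping"))
    await asyncio.sleep(0)        # let the call send and start waiting
    comm.handle_recv("reply-0")   # e.g. one reply per data-parallel rank
    comm.handle_recv("reply-1")
    print(await task)             # ['reply-0', 'reply-1']

asyncio.run(main())
```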
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -15,10 +15,13 @@

 import logging
 import threading
-from typing import Optional
+from typing import Optional, Tuple
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -27,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -46,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
@@ -74,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -151,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.
+            self.model_runner.token_to_kv_pool_allocator,
         )

     def forward_batch_generation(
@@ -159,7 +167,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ):
+    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
@@ -205,7 +213,10 @@ class TpModelWorker:
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            MultiprocessingSerializer.deserialize(
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors
+            ),
+            load_format=recv_req.load_format,
         )
         return success, message
```
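The two new optional pool arguments let a caller hand existing memory pools to the worker instead of having `ModelRunner` allocate fresh ones. Given the existing `is_draft_worker` flag and the `eagle_worker.py` changes listed above, a plausible use is a speculative-decoding draft worker sharing the target worker's pools; the wiring below is a hypothetical sketch, with argument lists illustrative rather than the exact API:

```python
# Hypothetical wiring: let a draft worker reuse the target worker's pools.
target = TpModelWorker(server_args=server_args, gpu_id=0, tp_rank=0,
                       dp_rank=None, nccl_port=nccl_port)
req_to_token_pool, token_to_kv_pool_allocator = target.get_memory_pool()

draft = TpModelWorker(server_args=server_args, gpu_id=0, tp_rank=0,
                      dp_rank=None, nccl_port=nccl_port,
                      is_draft_worker=True,
                      req_to_token_pool=req_to_token_pool,
                      token_to_kv_pool_allocator=token_to_kv_pool_allocator)
```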