sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -0,0 +1,614 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import (
+        EmbeddingBatchResult,
+        GenerationBatchResult,
+        ScheduleBatch,
+    )
+
+
+class SchedulerOutputProcessorMixin:
+    """
+    This class implements the output processing logic for Scheduler.
+    We put them into a separate file to make the `scheduler.py` shorter.
+    """
+
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
+        skip_stream_req = None
+
+        if self.is_generation:
+            (
+                logits_output,
+                next_token_ids,
+                extend_input_len_per_req,
+                extend_logprob_start_len_per_req,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.extend_input_len_per_req,
+                result.extend_logprob_start_len_per_req,
+                result.bid,
+            )
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                next_token_ids = next_token_ids.tolist()
+                if batch.return_logprob:
+                    if logits_output.next_token_logprobs is not None:
+                        logits_output.next_token_logprobs = (
+                            logits_output.next_token_logprobs.tolist()
+                        )
+                    if logits_output.input_token_logprobs is not None:
+                        logits_output.input_token_logprobs = tuple(
+                            logits_output.input_token_logprobs.tolist()
+                        )
+
+            hidden_state_offset = 0
+
+            # Check finish conditions
+            logprob_pt = 0
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+                if req.is_retracted:
+                    continue
+
+                if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                    # Free the one delayed token for the mixed decode batch
+                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
+                    continue
+
+                if req.is_chunked <= 0:
+                    # req output_ids are set here
+                    req.output_ids.append(next_token_id)
+                    req.check_finished()
+
+                    if req.finished():
+                        self.tree_cache.cache_finished_req(req)
+                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                        # This updates radix so others can match
+                        self.tree_cache.cache_unfinished_req(req)
+
+                    if req.return_logprob:
+                        assert extend_logprob_start_len_per_req is not None
+                        assert extend_input_len_per_req is not None
+                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                        extend_input_len = extend_input_len_per_req[i]
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_logprob_return_values(
+                            i,
+                            req,
+                            logprob_pt,
+                            next_token_ids,
+                            num_input_logprobs,
+                            logits_output,
+                        )
+                        logprob_pt += num_input_logprobs
+
+                    if (
+                        req.return_hidden_states
+                        and logits_output.hidden_states is not None
+                    ):
+                        req.hidden_states.append(
+                            logits_output.hidden_states[
+                                hidden_state_offset : (
+                                    hidden_state_offset := hidden_state_offset
+                                    + len(req.origin_input_ids)
+                                )
+                            ]
+                            .cpu()
+                            .clone()
+                            .tolist()
+                        )
+
+                    if req.grammar is not None:
+                        req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
+                else:
+                    # being chunked reqs' prefill is not finished
+                    req.is_chunked -= 1
+                    # There is only at most one request being currently chunked.
+                    # Because this request does not finish prefill,
+                    # we don't want to stream the request currently being chunked.
+                    skip_stream_req = req
+
+                    # Incrementally update input logprobs.
+                    if req.return_logprob:
+                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                        extend_input_len = extend_input_len_per_req[i]
+                        if extend_logprob_start_len < extend_input_len:
+                            # Update input logprobs.
+                            num_input_logprobs = (
+                                extend_input_len - extend_logprob_start_len
+                            )
+                            self.add_input_logprob_return_values(
+                                i,
+                                req,
+                                logits_output,
+                                logprob_pt,
+                                num_input_logprobs,
+                                last_prefill_chunk=False,
+                            )
+                            logprob_pt += num_input_logprobs
+
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
+        else:  # embedding or reward model
+            embeddings, bid = result.embeddings, result.bid
+            embeddings = embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                if req.is_retracted:
+                    continue
+
+                req.embedding = embeddings[i]
+                if req.is_chunked <= 0:
+                    # Dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                    if req.finished():
+                        self.tree_cache.cache_finished_req(req)
+                    else:
+                        self.tree_cache.cache_unfinished_req(req)
+                else:
+                    # being chunked reqs' prefill is not finished
+                    req.is_chunked -= 1
+
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
+
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
+        self.num_generated_tokens += len(batch.reqs)
+
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
+            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()
+
+        self.token_to_kv_pool_allocator.free_group_begin()
+
+        # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if req.is_retracted:
+                continue
+
+            if self.enable_overlap and req.finished():
+                # Free the one extra delayed token
+                if self.page_size == 1:
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                else:
+                    # Only free when the extra token is in a new page
+                    if (
+                        len(req.origin_input_ids) + len(req.output_ids) - 1
+                    ) % self.page_size == 0:
+                        self.token_to_kv_pool_allocator.free(
+                            batch.out_cache_loc[i : i + 1]
+                        )
+                continue
+
+            if batch.spec_algorithm.is_none():
+                # speculative worker will solve the output_ids in speculative decoding
+                req.output_ids.append(next_token_id)
+
+            req.check_finished()
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
+            if req.return_logprob and batch.spec_algorithm.is_none():
+                # speculative worker handles logprob in speculative decoding
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
+                if req.top_logprobs_num > 0:
+                    req.output_top_logprobs_val.append(
+                        logits_output.next_token_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.next_token_top_logprobs_idx[i]
+                    )
+                if req.token_ids_logprob is not None:
+                    req.output_token_ids_logprobs_val.append(
+                        logits_output.next_token_token_ids_logprobs_val[i]
+                    )
+                    req.output_token_ids_logprobs_idx.append(
+                        logits_output.next_token_token_ids_logprobs_idx[i]
+                    )
+
+            if req.return_hidden_states and logits_output.hidden_states is not None:
+                req.hidden_states.append(
+                    logits_output.hidden_states[i].cpu().clone().tolist()
+                )
+
+            if req.grammar is not None and batch.spec_algorithm.is_none():
+                req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
+
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
+        self.stream_output(batch.reqs, batch.return_logprob)
+
+        self.token_to_kv_pool_allocator.free_group_end()
+
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if (
+            self.attn_tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
+
+    def add_input_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        output: LogitsProcessorOutput,
+        logprob_pt: int,
+        num_input_logprobs: int,
+        last_prefill_chunk: bool,  # If True, it means prefill is finished.
+    ):
+        """Incrementally add input logprobs to `req`.
+
+        Args:
+            i: The request index in a batch.
+            req: The request. Input logprobs inside req are modified as a
+                consequence of the API
+            fill_ids: The prefill ids processed.
+            output: Logit processor output that's used to compute input logprobs
+            last_prefill_chunk: True if it is the last prefill (when chunked).
+                Some of input logprob operation should only happen at the last
+                prefill (e.g., computing input token logprobs).
+        """
+        assert output.input_token_logprobs is not None
+        if req.input_token_logprobs is None:
+            req.input_token_logprobs = []
+        if req.temp_input_top_logprobs_val is None:
+            req.temp_input_top_logprobs_val = []
+        if req.temp_input_top_logprobs_idx is None:
+            req.temp_input_top_logprobs_idx = []
+        if req.temp_input_token_ids_logprobs_val is None:
+            req.temp_input_token_ids_logprobs_val = []
+        if req.temp_input_token_ids_logprobs_idx is None:
+            req.temp_input_token_ids_logprobs_idx = []
+
+        if req.input_token_logprobs_val is not None:
+            # The input logprob has been already computed. It only happens
+            # upon retract.
+            if req.top_logprobs_num > 0:
+                assert req.input_token_logprobs_val is not None
+            return
+
+        # Important for the performance.
+        assert isinstance(output.input_token_logprobs, tuple)
+        input_token_logprobs: Tuple[int] = output.input_token_logprobs
+        input_token_logprobs = input_token_logprobs[
+            logprob_pt : logprob_pt + num_input_logprobs
+        ]
+        req.input_token_logprobs.extend(input_token_logprobs)
+
+        if req.top_logprobs_num > 0:
+            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
+
+        if req.token_ids_logprob is not None:
+            req.temp_input_token_ids_logprobs_val.append(
+                output.input_token_ids_logprobs_val[i]
+            )
+            req.temp_input_token_ids_logprobs_idx.append(
+                output.input_token_ids_logprobs_idx[i]
+            )
+
+        if last_prefill_chunk:
+            input_token_logprobs = req.input_token_logprobs
+            req.input_token_logprobs = None
+            assert req.input_token_logprobs_val is None
+            assert req.input_token_logprobs_idx is None
+            assert req.input_top_logprobs_val is None
+            assert req.input_top_logprobs_idx is None
+
+            # Compute input_token_logprobs_val
+            # Always pad the first one with None.
+            req.input_token_logprobs_val = [None]
+            req.input_token_logprobs_val.extend(input_token_logprobs)
+            # The last input logprob is for sampling, so just pop it out.
+            req.input_token_logprobs_val.pop()
+
+            # Compute input_token_logprobs_idx
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+            # Clip the padded hash values from image tokens.
+            # Otherwise, it will lead to detokenization errors.
+            input_token_logprobs_idx = [
+                x if x < self.model_config.vocab_size - 1 else 0
+                for x in input_token_logprobs_idx
+            ]
+            req.input_token_logprobs_idx = input_token_logprobs_idx
+
+            if req.top_logprobs_num > 0:
+                req.input_top_logprobs_val = [None]
+                req.input_top_logprobs_idx = [None]
+                assert len(req.temp_input_token_ids_logprobs_val) == len(
+                    req.temp_input_token_ids_logprobs_idx
+                )
+                for val, idx in zip(
+                    req.temp_input_top_logprobs_val,
+                    req.temp_input_top_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_top_logprobs_val.extend(val)
+                    req.input_top_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_top_logprobs_val.pop()
+                req.input_top_logprobs_idx.pop()
+                req.temp_input_top_logprobs_idx = None
+                req.temp_input_top_logprobs_val = None
+
+            if req.token_ids_logprob is not None:
+                req.input_token_ids_logprobs_val = [None]
+                req.input_token_ids_logprobs_idx = [None]
+
+                for val, idx in zip(
+                    req.temp_input_token_ids_logprobs_val,
+                    req.temp_input_token_ids_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_token_ids_logprobs_val.extend(val)
+                    req.input_token_ids_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_token_ids_logprobs_val.pop()
+                req.input_token_ids_logprobs_idx.pop()
+                req.temp_input_token_ids_logprobs_idx = None
+                req.temp_input_token_ids_logprobs_val = None
+
+            if req.return_logprob:
+                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+                assert len(req.input_token_logprobs_val) == relevant_tokens_len
+                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+                if req.top_logprobs_num > 0:
+                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+                if req.token_ids_logprob is not None:
+                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
+
+    def add_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        num_input_logprobs: int,
+        output: LogitsProcessorOutput,
+    ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        self.add_input_logprob_return_values(
+            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+        )
+
+        if req.top_logprobs_num > 0:
+            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
+
+        if req.token_ids_logprob is not None:
+            req.output_token_ids_logprobs_val.append(
+                output.next_token_token_ids_logprobs_val[i]
+            )
+            req.output_token_ids_logprobs_idx.append(
+                output.next_token_token_ids_logprobs_idx[i]
+            )
+
+        return num_input_logprobs
+
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
+        """Stream the output to detokenizer."""
+        if self.is_generation:
+            self.stream_output_generation(reqs, return_logprob, skip_req)
+        else:  # embedding or reward model
+            self.stream_output_embedding(reqs)
+
+    def stream_output_generation(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
+        decoded_texts = []
+        decode_ids_list = []
+        read_offsets = []
+        output_ids = []
+
+        skip_special_tokens = []
+        spaces_between_special_tokens = []
+        no_stop_trim = []
+        prompt_tokens = []
+        completion_tokens = []
+        cached_tokens = []
+        spec_verify_ct = []
+        output_hidden_states = None
+
+        if return_logprob:
+            input_token_logprobs_val = []
+            input_token_logprobs_idx = []
+            output_token_logprobs_val = []
+            output_token_logprobs_idx = []
+            input_top_logprobs_val = []
+            input_top_logprobs_idx = []
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            input_token_ids_logprobs_val = []
+            input_token_ids_logprobs_idx = []
+            output_token_ids_logprobs_val = []
+            output_token_ids_logprobs_idx = []
+        else:
+            input_token_logprobs_val = input_token_logprobs_idx = (
+                output_token_logprobs_val
+            ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                input_top_logprobs_idx
+            ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                input_token_ids_logprobs_val
+            ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+                output_token_ids_logprobs_idx
+            ) = None
+
+        for req in reqs:
+            if req is skip_req:
+                continue
+
+            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+            if self.model_config.is_multimodal_gen and req.to_abort:
+                continue
+
+            if (
+                req.finished()
+                # If stream, follow the given stream_interval
+                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+                # always increase one-by-one.
+                or (
+                    not req.stream
+                    and len(req.output_ids) % 50 == 0
+                    and not self.model_config.is_multimodal_gen
+                )
+            ):
+                rids.append(req.rid)
+                finished_reasons.append(
+                    req.finished_reason.to_json() if req.finished_reason else None
+                )
+                decoded_texts.append(req.decoded_text)
+                decode_ids, read_offset = req.init_incremental_detokenize()
+                decode_ids_list.append(decode_ids)
+                read_offsets.append(read_offset)
+                if self.skip_tokenizer_init:
+                    output_ids.append(req.output_ids)
+                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )
+                no_stop_trim.append(req.sampling_params.no_stop_trim)
+                prompt_tokens.append(len(req.origin_input_ids))
+                completion_tokens.append(len(req.output_ids))
+                cached_tokens.append(req.cached_tokens)
+
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
+                if return_logprob:
+                    input_token_logprobs_val.append(req.input_token_logprobs_val)
+                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                    output_token_logprobs_val.append(req.output_token_logprobs_val)
+                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                    input_top_logprobs_val.append(req.input_top_logprobs_val)
+                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                    output_top_logprobs_val.append(req.output_top_logprobs_val)
+                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                    input_token_ids_logprobs_val.append(
+                        req.input_token_ids_logprobs_val
+                    )
+                    input_token_ids_logprobs_idx.append(
+                        req.input_token_ids_logprobs_idx
+                    )
+                    output_token_ids_logprobs_val.append(
+                        req.output_token_ids_logprobs_val
+                    )
+                    output_token_ids_logprobs_idx.append(
+                        req.output_token_ids_logprobs_idx
+                    )
+
+                if req.return_hidden_states:
+                    if output_hidden_states is None:
+                        output_hidden_states = []
+                    output_hidden_states.append(req.hidden_states)
+
+        # Send to detokenizer
+        if rids:
+            if self.model_config.is_multimodal_gen:
+                return
+            self.send_to_detokenizer.send_pyobj(
+                BatchTokenIDOut(
+                    rids,
+                    finished_reasons,
+                    decoded_texts,
+                    decode_ids_list,
+                    read_offsets,
+                    output_ids,
+                    skip_special_tokens,
+                    spaces_between_special_tokens,
+                    no_stop_trim,
+                    prompt_tokens,
+                    completion_tokens,
+                    cached_tokens,
+                    spec_verify_ct,
+                    input_token_logprobs_val,
+                    input_token_logprobs_idx,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                    input_token_ids_logprobs_val,
+                    input_token_ids_logprobs_idx,
+                    output_token_ids_logprobs_val,
+                    output_token_ids_logprobs_idx,
+                    output_hidden_states,
+                )
+            )
+
+    def stream_output_embedding(self, reqs: List[Req]):
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
+        embeddings = []
+        prompt_tokens = []
+        cached_tokens = []
+        for req in reqs:
+            if req.finished():
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+                cached_tokens.append(req.cached_tokens)
+        self.send_to_detokenizer.send_pyobj(
+            BatchEmbeddingOut(
+                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+            )
+        )
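Note (illustrative, not part of the diff): the new scheduler_output_processor_mixin.py is a mixin, so its methods read attributes such as self.is_generation, self.tree_cache, self.tp_worker, and self.send_to_detokenizer that the concrete Scheduler class is expected to define. A minimal, self-contained sketch of that pattern, using hypothetical names (ToyOutputProcessorMixin, ToyScheduler):

class ToyOutputProcessorMixin:
    """Output-processing methods kept in a separate class to keep the main class short."""

    def process_batch_result(self, token_ids):
        # Relies on state (`processed`, `log_interval`) that the class mixing this in must define.
        self.processed.extend(token_ids)
        if self.log_interval and len(self.processed) % self.log_interval == 0:
            print(f"processed {len(self.processed)} tokens")


class ToyScheduler(ToyOutputProcessorMixin):
    def __init__(self, log_interval: int = 2):
        self.processed = []              # state the mixin expects
        self.log_interval = log_interval


sched = ToyScheduler()
sched.process_batch_result([101, 102])  # logs: processed 2 tokens
sched.process_batch_result([103])       # no log yet (3 % 2 != 0)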
sglang/srt/managers/tokenizer_manager.py
@@ -372,13 +372,12 @@ class TokenizerManager:
         )
         input_ids = self.tokenizer.encode(input_text)
 
+        image_inputs: Dict = await self.image_processor.process_images_async(
+            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+        )
+        if image_inputs and "input_ids" in image_inputs:
+            input_ids = image_inputs["input_ids"]
         if self.is_generation:
-            # TODO: also support getting embeddings for multimodal models
-            image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
-            )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
@@ -438,6 +437,7 @@ class TokenizerManager:
                 obj.rid,
                 input_text,
                 input_ids,
+                image_inputs,
                 sampling_params,
             )
 
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -103,6 +103,9 @@ class TpModelWorkerClient:
             self.worker.model_runner.token_to_kv_pool_allocator,
         )
 
+    def get_kv_cache(self):
+        return self.worker.model_runner.token_to_kv_pool
+
     def forward_thread_func(self):
         try:
             with torch.get_device_module(self.device).stream(self.forward_stream):
@@ -203,7 +206,7 @@ class TpModelWorkerClient:
             -(self.future_token_ids_ct + 1),
             -(self.future_token_ids_ct + 1 + bs),
             -1,
-            dtype=torch.
+            dtype=torch.int64,
             device=self.device,
         )
         self.future_token_ids_ct = (
sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import
+from typing import Any, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def evict(self, num_tokens: int
+    def evict(self, num_tokens: int):
         pass
 
     @abstractmethod
-    def inc_lock_ref(self, node):
+    def inc_lock_ref(self, node: Any):
         pass
 
     @abstractmethod
-    def dec_lock_ref(self, node):
+    def dec_lock_ref(self, node: Any):
         pass
 
-    @abstractmethod
     def evictable_size(self):
-
+        return 0
 
-    @abstractmethod
     def protected_size(self):
-
+        return 0
 
     def total_size(self):
         raise NotImplementedError()
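Note (illustrative, not part of the diff): after this change, evictable_size() and protected_size() are regular methods that default to returning 0, so only caches that actually track those quantities need to override them. A minimal sketch under that assumption, using hypothetical names (ToyBasePrefixCache, ToyPrefixCache) and only the methods visible in the hunk above:

from abc import ABC, abstractmethod
from typing import Any


class ToyBasePrefixCache(ABC):
    # Still-abstract methods: every concrete cache must implement these.
    @abstractmethod
    def evict(self, num_tokens: int):
        pass

    @abstractmethod
    def inc_lock_ref(self, node: Any):
        pass

    @abstractmethod
    def dec_lock_ref(self, node: Any):
        pass

    # No longer abstract: sensible defaults for caches that track nothing.
    def evictable_size(self):
        return 0

    def protected_size(self):
        return 0


class ToyPrefixCache(ToyBasePrefixCache):
    """Minimal concrete cache; the default size accessors are good enough."""

    def evict(self, num_tokens: int):
        pass

    def inc_lock_ref(self, node: Any):
        pass

    def dec_lock_ref(self, node: Any):
        pass


cache = ToyPrefixCache()
assert cache.evictable_size() == 0 and cache.protected_size() == 0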