sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -37,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
@@ -53,9 +52,6 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

@@ -65,19 +61,41 @@ class GenerateReqInput:
         ):
             raise ValueError("Either text or input_ids should be provided.")

-
-
-
-
-
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
         else:
-            if self.
-                is_single =
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
             else:
-
-
-
-
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
+
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
@@ -89,79 +107,54 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
             else:
-
-
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size

             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                #
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num

             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1

             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-
-
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1

             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1

             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1


 @dataclass
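The rewritten `post_init` above computes `batch_size` and `parallel_sample_num` up front, promotes a single prompt to a batch when parallel sampling is requested, and sizes the per-request lists as `batch_size + parallel_sample_num * batch_size` (the first `batch_size` entries only cache the shared prefix). A standalone sketch of that arithmetic in plain Python, not the sglang class itself:

# Sketch of the request-count math in GenerateReqInput.post_init above.
text = "Write a haiku."            # a single prompt
sampling_params = {"n": 2}         # two parallel samples requested

batch_size = 1 if isinstance(text, str) else len(text)
parallel_sample_num = sampling_params.get("n", 1)

if parallel_sample_num > 1 and isinstance(text, str):
    text = [text]                  # promote the single prompt to a batch of one

if parallel_sample_num == 1:
    num = batch_size
else:
    # first batch_size entries cache the shared prefix, the rest hold the samples
    num = batch_size + parallel_sample_num * batch_size

print(num)  # 3 -> one prefix-caching request plus two sampled continuations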
@@ -172,12 +165,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The
-
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +177,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None

     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
@@ -206,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None

-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -215,12 +200,11 @@ class EmbeddingReqInput:
             raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
-            is_single = isinstance(self.text, str)
+            self.is_single = isinstance(self.text, str)
         else:
-            is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+            self.is_single = isinstance(self.input_ids[0], int)

-        if is_single:
+        if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
@@ -254,6 +238,50 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams


+@dataclass
+class RewardReqInput:
+    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
+    conv: Union[List[List[Dict]], List[Dict]]
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        self.is_single = isinstance(self.conv[0], dict)
+
+        if self.is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = len(self.conv)
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedRewardReqInput:
+    # The request id
+    rid: str
+    # The input text
+    input_text: str
+    # The input token ids
+    input_ids: List[int]
+    # Dummy sampling params for compatibility
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
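`RewardReqInput` above tells a single conversation apart from a batch by the type of its first element and pins `max_new_tokens` to 1, since a reward model only needs one forward pass per prompt. A standalone sketch of that single-vs-batch check; it mirrors, rather than imports, the class added above:

# Mirrors the is_single check in RewardReqInput.post_init (illustrative only).
single_conv = [{"role": "user", "content": "Is this answer helpful?"}]
batched_convs = [single_conv, single_conv]


def is_single(conv) -> bool:
    # One conversation is a list of message dicts; a batch is a list of such lists.
    return isinstance(conv[0], dict)


print(is_single(single_conv))    # True  -> one rid, sampling_params["max_new_tokens"] = 1
print(is_single(batched_convs))  # False -> per-conversation rids and sampling params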
@@ -268,10 +296,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-

 @dataclass
 class BatchStrOut:
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 import logging
 from dataclasses import dataclass
@@ -31,6 +41,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
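The new module docstring above describes the hand-off chain ScheduleBatch -> ModelWorkerBatch -> ForwardBatch: the scheduler keeps cheap CPU-side metadata, and tensors are only materialized for the worker. A minimal standalone sketch of that separation; the class names are illustrative, not the sglang classes:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ScheduleBatchSketch:
    # CPU-side scheduling metadata, cheap to filter/merge every step
    input_ids: List[int]
    seq_lens: List[int]

    def to_worker_batch(self) -> "ModelWorkerBatchSketch":
        # Only at hand-off time do the Python lists become tensors
        return ModelWorkerBatchSketch(
            input_ids=torch.tensor(self.input_ids, dtype=torch.int32),
            seq_lens=torch.tensor(self.seq_lens, dtype=torch.int32),
        )


@dataclass
class ModelWorkerBatchSketch:
    # Tensor data consumed by the model worker / model runner
    input_ids: torch.Tensor
    seq_lens: torch.Tensor


batch = ScheduleBatchSketch(input_ids=[5, 7, 9], seq_lens=[3])
print(batch.to_worker_batch())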
@@ -102,14 +113,50 @@ class FINISH_ABORT(BaseFinishReason):
     }


+@dataclass
+class ImageInputs:
+    """The image related inputs."""
+
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"] or ["image"]
+        return ret
+
+
 class Req:
-    """
+    """The input and output status of a request."""

     def __init__(
         self,
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
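`ImageInputs.from_dict` above folds the per-image hashes into one value and slices it into pseudo token ids; per the inline comment, those `pad_values` stand in for real image tokens during prefix matching, so identical images can hit the same cached prefix. A standalone sketch of just that derivation, with made-up hashes and vocab size:

# Sketch of the pad-value derivation in ImageInputs.from_dict above.
vocab_size = 32000
image_hash = hash((123456789, 987654321))  # combined hash of the per-image hashes

pad_values = [
    image_hash % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_values)  # four deterministic pseudo token ids derived from the hash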
@@ -119,6 +166,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path

         # Memory info
@@ -127,6 +176,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -147,21 +197,13 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -363,28 +405,32 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids:
-    req_pool_indices:
-    seq_lens:
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None

     # Stream
     has_stream: bool = False

+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)

         return cls(
             reqs=reqs,
@@ -393,6 +439,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )

     def batch_size(self):
@@ -436,12 +483,12 @@ class ScheduleBatch:
         seq_lens = []

         # Allocate memory
-
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx =
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
@@ -467,18 +514,19 @@ class ScheduleBatch:
             pt += req.extend_input_len

         # Set fields
-        with
+        with out_cache_loc.device:
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(
-            self.seq_lens = torch.tensor(seq_lens
-        self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+            self.req_pool_indices = torch.tensor(req_pool_indices)
+            self.seq_lens = torch.tensor(seq_lens)

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        self.
-
-        self.
-        self.
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
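The `with out_cache_loc.device:` block above relies on `torch.device` being usable as a context manager (PyTorch 2.x): tensors created inside the block default to that device, so the CPU-side lists become GPU tensors without explicit `.to()` calls. A minimal standalone illustration:

import torch

# torch.device works as a context manager in PyTorch 2.x: tensors created
# inside the block are allocated on that device by default.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with device:
    seq_lens = torch.tensor([7, 12, 3])        # created directly on `device`
    req_pool_indices = torch.tensor([0, 1, 2])

print(seq_lens.device, req_pool_indices.device)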
@@ -493,20 +541,20 @@ class ScheduleBatch:
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs

-        self.
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.
-        self.
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)

     def check_decode_mem(self):
         bs = len(self.reqs)
@@ -598,7 +646,7 @@ class ScheduleBatch:

         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self,
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

@@ -654,15 +702,9 @@
                 self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
-                if req.
-                    (
-                        req.
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = pad_input_ids_func(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )

                 jump_forward_reqs.append(req)
@@ -681,7 +723,9 @@
             for r in self.reqs
         ]

-        self.input_ids = torch.tensor(
+        self.input_ids = torch.tensor(
+            input_ids, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.seq_lens.add_(1)

         # Alloc mem
@@ -703,33 +747,110 @@
             return

         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(
-
-
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
+        else:
+            self.top_logprobs_nums = None
+
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)

-        self.sampling_info.
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)

-    def
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.
+        self.sampling_info.merge_batch(other.sampling_info)

-        self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
         self.out_cache_loc = None
-        self.
-
-
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+        self.reqs.extend(other.reqs)
+
+        self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
+
+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo