sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller_multi.py
CHANGED
@@ -71,12 +71,10 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +112,6 @@ class ControllerMulti:
                 self.server_args,
                 self.port_args,
                 pipe_controller_writer,
-                self.model_override_args,
                 True,
                 gpu_ids,
                 dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_override_args: dict,
 ):
     """Start a controller process."""

     configure_logger(server_args)

     try:
-        controller = ControllerMulti(server_args, port_args
+        controller = ControllerMulti(server_args, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

sglang/srt/managers/controller_single.py
CHANGED
@@ -40,7 +40,6 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_override_args,
             )

         # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

@@ -126,7 +123,6 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).

 import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -43,6 +43,7 @@ class GenerateReqInput:
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
+    # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
     # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
@@ -50,6 +51,13 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # The modalities of the image data [image, multi-images, video]
+    modalities: Optional[List[str]] = None
+
+    is_single: bool = True
+
+    # LoRA related
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
@@ -177,6 +185,11 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
+    # Modalities of the input images
+    modalites: Optional[List[str]] = None
+
+    # LoRA related
+    lora_path: Optional[str] = None  # None means just use the base model


 @dataclass
@@ -190,6 +203,8 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None

+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
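
The request structs above gain per-request multimodal and LoRA fields. A small self-contained sketch of how the new fields could be filled for a streaming video request routed to a LoRA adapter; the dataclass is a simplified stand-in, not the real `GenerateReqInput`, and the adapter path is a made-up example:

```python
# Sketch only: a simplified stand-in for GenerateReqInput (not importable sglang code).
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class GenerateReqInputSketch:
    text: Optional[Union[List[str], str]] = None
    stream: bool = False
    # New in this diff: modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
    # New in this diff: per-request LoRA adapter path; None means the base model
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # New in this diff: single request vs. a batch of requests
    is_single: bool = True


req = GenerateReqInputSketch(
    text="Describe the clip.",
    stream=True,
    modalities=["video"],
    lora_path="lora-adapters/my-adapter",  # hypothetical path
)
print(req)
```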
sglang/srt/managers/policy_scheduler.py
CHANGED
@@ -108,18 +108,25 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        running_batch: ScheduleBatch,
+        new_token_ratio: float,
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.running_batch = running_batch
+        self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens

+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
@@ -136,16 +143,20 @@ class PrefillAdder:
             )
         )

-    def remove_running_tokens(
-        self, running_batch: ScheduleBatch, new_token_ratio: float
-    ):
+    def remove_running_tokens(self, running_batch: ScheduleBatch):
         self.rem_total_tokens -= sum(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
                     CLIP_MAX_NEW_TOKENS,
                 )
-                * new_token_ratio
+                * self.new_token_ratio
+                for r in running_batch.reqs
+            ]
+        )
+        self.rem_total_tokens_ -= sum(
+            [
+                r.sampling_params.max_new_tokens - len(r.output_ids)
                 for r in running_batch.reqs
             ]
         )
@@ -154,6 +165,7 @@ class PrefillAdder:
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -161,7 +173,29 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

+    def add_inflight_req_ignore_eos(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            0,
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
     def add_inflight_req(self, req: Req):
+        if req.sampling_params.ignore_eos:
+            return self.add_inflight_req_ignore_eos(req)
+
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -190,7 +224,90 @@ class PrefillAdder:
         delta = self.tree_cache.dec_lock_ref(last_node)
         self.rem_total_tokens += delta

+    def add_one_req_ignore_eos(self, req: Req):
+        def get_req_state(r):
+            new_token_ratio = (
+                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+            )
+            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+                r.output_ids
+            )
+            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+            if tokens_left > 0:
+                return (tokens_left, tokens_occupied)
+
+            return None
+
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
+                    for r in self.running_batch.reqs:
+                        state = get_req_state(r)
+                        if state is not None:
+                            self.req_states.append(state)
+                for r in self.can_run_list:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+                state = get_req_state(req)
+                if state is not None:
+                    self.req_states.append(state)
+
+                self.req_states.sort(key=lambda x: x[0])
+            else:
+                state = get_req_state(req)
+                if state is not None:
+                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                        if tokens_left >= state[0]:
+                            self.req_states.insert(i, state)
+                            break
+                    else:
+                        self.req_states.append(state)
+
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                decode_steps = (
+                    self.req_states[i + 1][0]
+                    if i + 1 < len(self.req_states)
+                    else tokens_left
+                )
+                bs = len(self.req_states) - i
+                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                    return False
+                tokens_freed += tokens_occupied
+
+        if req.extend_input_len <= self.rem_chunk_tokens:
+            self.can_run_list.append(req)
+            self._prefill_one_req(
+                0,
+                req.extend_input_len,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+            )
+        else:
+            # Chunked prefill
+            trunc_len = self.rem_chunk_tokens
+            req.extend_input_len = trunc_len
+            req.fill_ids = req.fill_ids[:trunc_len]
+            self.can_run_list.append(req)
+            self.new_inflight_req = req
+            self._prefill_one_req(0, trunc_len, 0)
+
+        return True
+
     def add_one_req(self, req: Req):
+        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+            return self.add_one_req_ignore_eos(req)
+
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
@@ -233,4 +350,4 @@ class PrefillAdder:
         self.tree_cache.inc_lock_ref(req.last_node)
         self._prefill_one_req(prefix_len, trunc_len, 0)

-        return True
+        return True and not self.no_remaining_tokens()
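
Most of the new scheduling logic above is the `ignore_eos` admission path: it collects `(tokens_left, tokens_occupied)` states for the running requests, sorts them by remaining output tokens, and admits the new request only if a simulated decode schedule never exhausts the KV token budget before finished requests free their memory. A rough standalone sketch of that feasibility loop (simplified and renamed; the example numbers are illustrative only, not taken from sglang):

```python
# Sketch only: standalone mirror of the budget check in add_one_req_ignore_eos.
from typing import List, Tuple


def fits_in_token_budget(total_tokens: int, req_states: List[Tuple[float, int]]) -> bool:
    # Each state is (tokens_left, tokens_occupied); earlier-finishing requests first.
    req_states = sorted(req_states, key=lambda s: s[0])
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Decode horizon for this interval: the remaining tokens of the next request
        # to finish (or of this one, if it is the last).
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i  # requests still decoding in this interval
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied  # this request finishes and frees its tokens
    return True


print(fits_in_token_budget(100, [(30, 50), (60, 80), (90, 40)]))   # False: pool too small
print(fits_in_token_budget(1000, [(30, 50), (60, 80), (90, 40)]))  # True
```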
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -19,7 +19,7 @@ limitations under the License.

 import logging
 from dataclasses import dataclass
-from typing import
+from typing import List, Optional, Tuple, Union

 import torch

@@ -29,20 +29,19 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-
-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
+from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
 global_server_args_dict = {
-    "
-    "
-    "triton_attention_reduce_in_fp32":
-    "enable_mla":
+    "attention_backend": ServerArgs.attention_backend,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "enable_mla": ServerArgs.enable_mla,
+    "torchao_config": ServerArgs.torchao_config,
 }


@@ -53,8 +52,8 @@ class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error

-    def
-        raise NotImplementedError(
+    def to_json(self):
+        raise NotImplementedError()


 class FINISH_MATCHED_TOKEN(BaseFinishReason):
@@ -62,40 +61,57 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
         super().__init__()
         self.matched = matched

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }


-class
-    def __init__(self,
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
         super().__init__()
-        self.
+        self.matched = matched

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }


-class
-    def __init__(self,
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
         super().__init__()
-        self.
+        self.length = length

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "length",  # to match OpenAI API's return value
+            "length": self.length,
+        }


 class FINISH_ABORT(BaseFinishReason):
     def __init__(self):
         super().__init__(is_error=True)

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "abort",
+        }


 class Req:
     """Store all inforamtion of a request."""

-    def __init__(
+    def __init__(
+        self,
+        rid: str,
+        origin_input_text: str,
+        origin_input_ids: Tuple[int],
+        lora_path: Optional[str] = None,
+    ):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
|
|
103
119
|
self.origin_input_ids = origin_input_ids
|
104
120
|
self.output_ids = [] # Each decode stage's output ids
|
105
121
|
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
122
|
+
self.lora_path = lora_path
|
106
123
|
|
107
124
|
# Memory info
|
108
125
|
self.req_pool_idx = None
|
109
126
|
|
127
|
+
# Check finish
|
128
|
+
self.tokenizer = None
|
129
|
+
self.finished_reason = None
|
130
|
+
|
110
131
|
# For incremental decoding
|
111
132
|
# ----- | --------- read_ids -------|
|
112
133
|
# ----- | surr_ids |
|
@@ -125,38 +146,43 @@ class Req:
|
|
125
146
|
# this does not include the jump forward tokens.
|
126
147
|
self.completion_tokens_wo_jump_forward = 0
|
127
148
|
|
128
|
-
# For vision
|
149
|
+
# For vision inputs
|
129
150
|
self.pixel_values = None
|
130
151
|
self.image_sizes = None
|
131
152
|
self.image_offsets = None
|
132
153
|
self.pad_value = None
|
154
|
+
self.modalities = None
|
133
155
|
|
134
156
|
# Prefix info
|
135
|
-
self.extend_input_len = 0
|
136
157
|
self.prefix_indices = []
|
158
|
+
self.extend_input_len = 0
|
137
159
|
self.last_node = None
|
138
160
|
|
139
161
|
# Sampling parameters
|
140
162
|
self.sampling_params = None
|
141
163
|
self.stream = False
|
142
164
|
|
143
|
-
#
|
144
|
-
self.tokenizer = None
|
145
|
-
self.finished_reason = None
|
146
|
-
|
147
|
-
# Logprobs
|
165
|
+
# Logprobs (arguments)
|
148
166
|
self.return_logprob = False
|
149
|
-
self.embedding = None
|
150
167
|
self.logprob_start_len = 0
|
151
168
|
self.top_logprobs_num = 0
|
169
|
+
|
170
|
+
# Logprobs (return value)
|
152
171
|
self.normalized_prompt_logprob = None
|
153
172
|
self.input_token_logprobs = None
|
154
173
|
self.input_top_logprobs = None
|
155
174
|
self.output_token_logprobs = []
|
156
175
|
self.output_top_logprobs = []
|
176
|
+
|
177
|
+
# Logprobs (internal values)
|
157
178
|
# The tokens is prefilled but need to be considered as decode tokens
|
158
179
|
# and should be updated for the decode logprobs
|
159
180
|
self.last_update_decode_tokens = 0
|
181
|
+
# The relative logprob_start_len in an extend batch
|
182
|
+
self.extend_logprob_start_len = 0
|
183
|
+
|
184
|
+
# Embedding
|
185
|
+
self.embedding = None
|
160
186
|
|
161
187
|
# Constrained decoding
|
162
188
|
self.regex_fsm: RegexGuide = None
|
@@ -333,6 +359,9 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache

+    forward_mode: ForwardMode = None
+    sampling_info: SamplingBatchInfo = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
@@ -343,14 +372,19 @@ class ScheduleBatch:

     # For mixed chunekd prefill
     prefix_lens_cpu: List[int] = None
+    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

+    # Stream
+    has_stream: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
+        has_stream = any(req.stream for req in reqs)

         return cls(
             reqs=reqs,
@@ -358,18 +392,15 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             return_logprob=return_logprob,
+            has_stream=has_stream,
         )

     def batch_size(self):
-        return len(self.reqs)
+        return len(self.reqs)

     def is_empty(self):
         return len(self.reqs) == 0

-    def has_stream(self) -> bool:
-        # Return whether batch has at least 1 streaming request
-        return any(r.stream for r in self.reqs)
-
     def alloc_req_slots(self, num_reqs):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
@@ -396,6 +427,8 @@ class ScheduleBatch:
         return out_cache_loc

     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -410,8 +443,8 @@ class ScheduleBatch:
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            ext_len = seq_len - pre_len
             seq_lens.append(seq_len)
+            assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
@@ -419,9 +452,19 @@ class ScheduleBatch:
                 ] = req.prefix_indices

             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
-                out_cache_loc[pt : pt +
+                out_cache_loc[pt : pt + req.extend_input_len]
             )
-
+
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len, req.extend_input_len - 1
+                )
+            else:
+                extend_logprob_start_len = req.extend_input_len - 1
+
+            req.extend_logprob_start_len = extend_logprob_start_len
+            pt += req.extend_input_len

         # Set fields
         with torch.device("cuda"):
@@ -434,18 +477,13 @@ class ScheduleBatch:
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
         self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-
+        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
-
-
-        prefix_lens_cpu.extend(
-            [
-                len(r.origin_input_ids) + len(r.output_ids) - 1
-                for r in running_batch.reqs
-            ]
-        )
+        self.forward_mode = ForwardMode.MIXED
+        running_bs = running_batch.batch_size()

         for req in running_batch.reqs:
             req.fill_ids = req.origin_input_ids + req.output_ids
@@ -453,12 +491,22 @@ class ScheduleBatch:

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens +
+        extend_num_tokens = self.extend_num_tokens + running_bs
+
         self.merge(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens
-
+
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        self.prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+        self.extend_lens_cpu.extend([1] * running_bs)
+        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

     def check_decode_mem(self):
         bs = self.batch_size()
@@ -625,6 +673,8 @@ class ScheduleBatch:
         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
@@ -644,8 +694,6 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

-        self.sampling_info.update_regex_vocab_mask(self)
-
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -665,6 +713,7 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        self.has_stream = any(req.stream for req in self.reqs)

         self.sampling_info.filter(unfinished_indices, new_indices)

@@ -675,7 +724,6 @@ class ScheduleBatch:
         self.sampling_info.merge(other.sampling_info)

         self.reqs.extend(other.reqs)
-
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
@@ -686,18 +734,4 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
+        self.has_stream = any(req.stream for req in self.reqs)