sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,18 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
-from typing import List, Union
+from typing import List, Optional, Union

-import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs

+import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -46,15 +45,6 @@ global_server_args_dict = {
 logger = logging.getLogger(__name__)


-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
-
-
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
@@ -108,7 +98,10 @@ class Req:
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
-        self.input_ids = None
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        # Memory info
+        self.req_pool_idx = None

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -131,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None

         # Prefix info
@@ -149,6 +142,7 @@ class Req:

         # Logprobs
         self.return_logprob = False
+        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
@@ -169,6 +163,32 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None

+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        if tree_cache is not None:
+            self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                rid=self.rid, key=self.adjust_max_prefix_ids()
+            )
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+
+    def adjust_max_prefix_ids(self):
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        max_prefix_len = input_len
+
+        if self.sampling_params.max_new_tokens > 0:
+            # Need at least one token to compute logits
+            max_prefix_len = min(max_prefix_len, input_len - 1)
+
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+
+        if self.normalized_prompt_logprob is None:
+            # Need at least two tokens to compute normalized logprob
+            max_prefix_len = min(max_prefix_len, input_len - 2)
+
+        return self.fill_ids[:max_prefix_len]
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
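The new `Req.init_next_round_input` / `adjust_max_prefix_ids` pair computes `fill_ids` and caps how much of a request may be served from the prefix cache before each prefill round. A standalone sketch of the same capping rule, useful for checking the edge cases; the flattened helper below is illustrative and not part of sglang:

```python
from typing import List


def max_reusable_prefix(
    fill_ids: List[int],
    max_new_tokens: int,
    return_logprob: bool,
    logprob_start_len: int,
    has_normalized_prompt_logprob: bool,
) -> List[int]:
    input_len = len(fill_ids)
    max_prefix_len = input_len

    if max_new_tokens > 0:
        # Keep at least one uncached token so the forward pass produces logits.
        max_prefix_len = min(max_prefix_len, input_len - 1)

    if return_logprob:
        # Prompt logprobs are recomputed starting at logprob_start_len.
        max_prefix_len = min(max_prefix_len, logprob_start_len)

    if not has_normalized_prompt_logprob:
        # Need at least two uncached tokens to compute the normalized logprob.
        max_prefix_len = min(max_prefix_len, input_len - 2)

    return fill_ids[:max_prefix_len]


# A 16-token prompt with generation enabled can reuse at most 14 cached tokens here.
assert len(max_reusable_prefix(list(range(16)), 32, False, 0, False)) == 14
```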
@@ -183,6 +203,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]

@@ -207,16 +229,18 @@ class Req:
             return

         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
             return

-
-
-
-
-
-
-            )
+        last_token_id = self.output_ids[-1]
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

         if len(self.sampling_params.stop_strs) > 0:
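`check_finished` can now detect EOS even when the request has no tokenizer attached, falling back to `sampling_params.stop_token_ids`. A minimal, self-contained illustration of that decision (the helper is hypothetical, not sglang code):

```python
from typing import Optional, Sequence


def matched_eos(
    last_token_id: int,
    eos_token_id: Optional[int],
    stop_token_ids: Sequence[int],
    ignore_eos: bool = False,
) -> bool:
    if ignore_eos:
        return False
    if eos_token_id is None:
        # Tokenizer-less server: rely on explicit stop token ids.
        return last_token_id in stop_token_ids
    return last_token_id == eos_token_id


assert matched_eos(2, None, stop_token_ids=[2])                   # no tokenizer
assert not matched_eos(7, 2, stop_token_ids=[])                   # tokenizer present, not EOS
assert not matched_eos(2, 2, stop_token_ids=[], ignore_eos=True)  # ignore_eos wins
```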
@@ -284,20 +308,19 @@ class Req:


 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all inforamtion of a batch."""

     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: RadixCache
+    tree_cache: BasePrefixCache

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
@@ -306,17 +329,11 @@ class Batch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None

     @classmethod
@@ -331,6 +348,9 @@ class Batch:
             return_logprob=return_logprob,
         )

+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0

@@ -338,52 +358,22 @@ class Batch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-
+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "Out of memory. "
                 "Please set a smaller number for `--max-running-requests`."
             )
+        return req_pool_indices

-
-
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
-
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(
-                out_cache_loc = self.token_to_kv_pool.alloc(
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

             if out_cache_loc is None:
                 logger.error("Prefill out of memory. Try to lower your batch size.")
@@ -391,40 +381,11 @@ class Batch:
                     self.tree_cache.pretty_print()
                 exit(1)

-
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
-
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
-        self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        return out_cache_loc

+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
             dtype=torch.float,
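`alloc_req_slots` and `alloc_token_slots` factor the allocate-evict-retry logic out of `prepare_for_extend` so other paths can reuse it. A toy version of the token-slot path under simplified assumptions (a list stands in for the KV pool and a flat list for the radix cache; all names are illustrative):

```python
from typing import Callable, List, Optional


class ToyTokenPool:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, n: int) -> Optional[List[int]]:
        if len(self.free_slots) < n:
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)


class ToyPrefixCache:
    def __init__(self):
        self.cached: List[int] = []

    def evict(self, num_tokens: int, free_cb: Callable[[List[int]], None]) -> None:
        victims, self.cached = self.cached[:num_tokens], self.cached[num_tokens:]
        free_cb(victims)


def alloc_token_slots(pool: ToyTokenPool, cache: ToyPrefixCache, num_tokens: int) -> List[int]:
    out = pool.alloc(num_tokens)
    if out is None:
        # Evict cached prefixes back into the pool, then retry once.
        cache.evict(num_tokens, pool.free)
        out = pool.alloc(num_tokens)
    if out is None:
        raise RuntimeError("Prefill out of memory. Try to lower your batch size.")
    return out


pool, cache = ToyTokenPool(4), ToyPrefixCache()
cache.cached = pool.alloc(3)              # 3 slots currently held by the prefix cache
print(alloc_token_slots(pool, cache, 4))  # eviction frees them, the retry succeeds
```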
@@ -436,20 +397,79 @@ class Batch:
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
-
-
-
-
-
-
-
-
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=self,
             device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
         )
-
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
+    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+        bs = self.batch_size()
+        reqs = self.reqs
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
+        seq_lens = []
+
+        # Allocate memory
+        req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        pt = 0
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)
+
+            if pre_len > 0:
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    :pre_len
+                ] = req.prefix_indices
+
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len
+
+        # Set fields
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+
+        self.extend_num_tokens = extend_num_tokens
+        self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)

     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True

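The per-tensor `frequency_penalties` / `presence_penalties` fields are replaced by a single `penaltylib.BatchedPenalizerOrchestrator` that owns every penalizer and is threaded through `filter_batch()`, `merge()`, and `sample()`. A toy presence-penalty penalizer sketching the "no-op when not required" idea the comment above describes; this is an illustration, not the penaltylib API:

```python
import torch


class ToyPresencePenalizer:
    def __init__(self, presence_penalties: torch.Tensor, vocab_size: int):
        self.penalties = presence_penalties                       # shape (batch_size,)
        self.required = bool((presence_penalties != 0).any())     # mirrors _is_required()
        if self.required:
            self.seen = torch.zeros(
                (presence_penalties.numel(), vocab_size), dtype=torch.bool
            )

    def cumulate_output_tokens(self, token_ids: torch.Tensor) -> None:
        if self.required:
            self.seen[torch.arange(token_ids.numel()), token_ids] = True

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.required:  # cheap path when no request asked for the penalty
            return logits
        return logits - self.seen.float() * self.penalties[:, None]


penalizer = ToyPresencePenalizer(torch.tensor([0.0, 1.5]), vocab_size=8)
logits = torch.zeros(2, 8)
penalizer.cumulate_output_tokens(torch.tensor([3, 5]))
print(penalizer.apply(logits)[1, 5])  # tensor(-1.5000)
```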
@@ -474,7 +494,6 @@ class Batch:

         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -492,20 +511,20 @@ class Batch:

             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -518,7 +537,7 @@ class Batch:
                 residual_size = max(0, residual_size)
                 self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

-            req.prefix_indices = None
+            req.prefix_indices = []
             req.last_node = None
             req.extend_input_len = 0

@@ -543,8 +562,6 @@ class Batch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -594,17 +611,7 @@ class Batch:
                 req.vid += 1

                 # insert the old request into tree_cache
-
-                req_pool_indices_cpu = self.req_pool_indices.tolist()
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=cur_all_ids,
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
-                )
-
-                # unlock the last node
-                self.tree_cache.dec_lock_ref(req.last_node)
+                self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
                 if req.pixel_values is not None:
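With `cache_finished_req`, the scheduler hands a finished (or jump-forward-interrupted) request back to the cache in one call instead of doing the bookkeeping itself. For orientation, a `Protocol` listing just the tree-cache methods this file touches after the change; it is inferred from the calls visible in this diff and may not match the real `BasePrefixCache` exactly:

```python
from typing import Any, Callable, List, Protocol, Tuple


class PrefixCacheLike(Protocol):
    """Tree-cache surface used by schedule_batch.py in this version (inferred)."""

    def match_prefix(self, rid: str, key: List[int]) -> Tuple[Any, Any]:
        """Return (prefix_indices, last_node) for the longest cached prefix of key."""

    def evict(self, num_tokens: int, evict_callback: Callable) -> None:
        """Evict at least num_tokens cached tokens, freeing them via the callback."""

    def cache_finished_req(self, req: Any, token_ids: List[int]) -> None:
        """Re-insert a finished (or interrupted) request's tokens into the cache."""

    def dec_lock_ref(self, last_node: Any) -> None:
        """Drop the lock reference protecting the request's last matched node."""

    def pretty_print(self) -> None:
        ...
```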
@@ -621,66 +628,74 @@ class Batch:
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)

-
-        self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)

         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                for r in self.reqs
             ]
+        else:
+            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
-        self.prefix_lens = None

         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            exit(1)
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)

         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
         self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])

-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
+        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
+        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
+        # needs to be called with pre-merged Batch.reqs.
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
         self.reqs.extend(other.reqs)

         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
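`filter_batch` keeps every batched per-request structure aligned by indexing them all with the same `new_indices` tensor and forwarding the same indices to the penalizer orchestrator. A tiny standalone illustration of that invariant with toy tensors:

```python
import torch

# Toy batched state for four requests.
seq_lens = torch.tensor([7, 12, 3, 9])
temperatures = torch.tensor([1.0, 0.7, 0.9, 0.2])
top_logprobs_nums = [0, 5, 0, 2]

# Requests 1 and 3 survive; every structure is filtered with the same indices.
unfinished_indices = [1, 3]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int64)

seq_lens = seq_lens[new_indices]
temperatures = temperatures[new_indices]
top_logprobs_nums = [top_logprobs_nums[i] for i in unfinished_indices]

print(seq_lens.tolist(), top_logprobs_nums)  # [12, 9] [5, 2]
```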
@@ -692,8 +707,6 @@ class Batch:
             "temperatures",
             "top_ps",
             "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
@@ -717,6 +730,7 @@ class Batch:
             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
@@ -734,7 +748,8 @@ class Batch:
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))

-
+        logits = self.penalizer_orchestrator.apply(logits)
+
         probs = torch.softmax(logits, dim=-1)

         if not global_server_args_dict["disable_flashinfer_sampling"]:
@@ -767,230 +782,9 @@ class Batch:
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

-
-
-
-@dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
-
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
+        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

-
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
-    # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
-
-    # Output options
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
+        return batch_next_token_ids


 def top_k_top_p_sampling_from_probs_torch(
@@ -1009,7 +803,7 @@ def top_k_top_p_sampling_from_probs_torch(
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
     except RuntimeError:
         batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
         success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
         return batch_next_token_ids, success
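The forward-pass state removed above (`ForwardMode`, `InputMetadata`, and the FlashInfer/Triton init helpers) lines up with the new `sglang/srt/model_executor/forward_batch_info.py` module added in the file list at the top, which appears to be its new home. The final hunk only changes the failure path of the torch sampling fallback; for context, here is a simplified, self-contained sketch of a batched top-k/top-p sampler with that same contract (zeroed int32 ids plus a success mask on failure). It is an illustration, not the sglang implementation:

```python
import torch


def top_k_top_p_sample_torch(probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor):
    """Sample one token per row after top-k/top-p filtering; on a multinomial
    failure, return zeroed int32 ids and an all-False success mask."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens outside the nucleus (top-p) ...
    probs_sort[(probs_sum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0
    # ... and outside the top-k.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.unsqueeze(-1)] = 0.0
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError:
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success
    batch_next_token_ids = torch.gather(probs_idx, 1, sampled_index).squeeze(1).to(torch.int32)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success


probs = torch.softmax(torch.randn(2, 8), dim=-1)
ids, ok = top_k_top_p_sample_torch(probs, top_ks=torch.tensor([4, 8]), top_ps=torch.tensor([0.9, 1.0]))
print(ids.shape, ok.all().item())  # torch.Size([2]) True
```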