sglang-0.4.3.post4-py3-none-any.whl → sglang-0.4.4.post1-py3-none-any.whl
This diff shows the changes between these two publicly released package versions as they appear in their public registries; it is provided for informational purposes only.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -49,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -158,15 +159,19 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
-    #
+    # The id of the single-image placeholder token
+    im_token_id: Optional[torch.Tensor] = None
     # All the images in the batch should share the same special image
     # bound token ids.
-    im_start_id: Optional[
-    im_end_id: Optional[
-    slice_start_id: Optional[
-    slice_end_id: Optional[
+    im_start_id: Optional[int] = None
+    im_end_id: Optional[int] = None
+    slice_start_id: Optional[int] = None
+    slice_end_id: Optional[int] = None
     tgt_sizes: Optional[list] = None
 
+    # denotes the number of valid image tokens in each image
+    images_emb_mask: Optional[torch.BoolTensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -186,11 +191,13 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
             "tgt_sizes",
+            "images_emb_mask",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -267,7 +274,6 @@ class Req:
             "__req__": self
         }
         self.sampling_params = sampling_params
-
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
 
@@ -309,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -324,6 +331,8 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
         self.token_ids_logprob = token_ids_logprob
+        self.temp_scaled_logprobs = False
+        self.top_p_normalized_logprobs = False
 
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
@@ -352,7 +361,7 @@ class Req:
         ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
             self.output_token_ids_logprobs_idx
         ) = None
-        self.hidden_states = []
+        self.hidden_states: List[List[float]] = []
 
         # Embedding (return values)
         self.embedding = None
@@ -383,13 +392,24 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-
-
-
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -423,28 +443,6 @@ class Req:
         all_ids = self.origin_input_ids_unpadded + self.output_ids
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-    def get_next_inc_detokenization(self):
-        if self.tokenizer is None:
-            return False, ""
-        read_ids, read_offset = self.init_incremental_detokenize()
-        surr_ids = read_ids[:read_offset]
-
-        surr_text = self.tokenizer.decode(
-            surr_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-        new_text = self.tokenizer.decode(
-            read_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-
-        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            return True, new_text[len(surr_text) :]
-
-        return False, ""
-
     def check_finished(self):
         if self.finished():
             return
@@ -528,19 +526,23 @@ class ScheduleBatch:
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
     enable_overlap: bool = False
+    # Tell whether the current running batch is full so that we can skip
+    # the check of whether to prefill new requests.
+    # This is an optimization to reduce the overhead of the prefill check.
+    batch_is_full: bool = False
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None  # shape: [b],
+    input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
-    req_pool_indices: torch.Tensor = None  # shape: [b],
+    req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None  # shape: [b],
-    output_ids: torch.Tensor = None  # shape: [b],
+    out_cache_loc: torch.Tensor = None  # shape: [b], int64
+    output_ids: torch.Tensor = None  # shape: [b], int64
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -555,6 +557,10 @@ class ScheduleBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprob post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
@@ -564,7 +570,7 @@ class ScheduleBatch:
     # It comes empty list if logprob is not required.
     extend_input_logprob_token_ids: Optional[torch.Tensor] = None
 
-    # For encoder-decoder
+    # For encoder-decoder architectures
     encoder_cached: Optional[List[bool]] = None
     encoder_lens: Optional[torch.Tensor] = None
     encoder_lens_cpu: Optional[List[int]] = None
@@ -601,6 +607,8 @@ class ScheduleBatch:
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
     ):
+        return_logprob = any(req.return_logprob for req in reqs)
+
        return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
@@ -608,7 +616,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
-            return_logprob=
+            return_logprob=return_logprob,
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
@@ -635,24 +643,83 @@ class ScheduleBatch:
         return req_pool_indices
 
     def alloc_token_slots(self, num_tokens: int):
+        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens)
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
+        if out_cache_loc is None:
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            error_msg = (
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            )
+            logger.error(error_msg)
+            if self.tree_cache is not None:
+                self.tree_cache.pretty_print()
+            raise RuntimeError(error_msg)
+
+        return out_cache_loc
+
+    def alloc_paged_token_slots_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < extend_num_tokens
+            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
+            if self.tree_cache is not None:
+                self.tree_cache.evict(
+                    extend_num_tokens
+                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
+                )
 
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
+            prefix_lens, seq_lens, last_loc, extend_num_tokens
+        )
         if out_cache_loc is None:
+            error_msg = (
+                f"Prefill out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {extend_num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        return out_cache_loc
+
+    def alloc_paged_token_slots_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
             if self.tree_cache is not None:
-                self.tree_cache.evict(
-
-
-            if out_cache_loc is None:
-                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-                logger.error(
-                    f"{phase_str} out of memory. Try to lower your batch size.\n"
-                    f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                self.tree_cache.evict(
+                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
-
-                self.tree_cache.pretty_print()
-                exit(1)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+        if out_cache_loc is None:
+            error_msg = (
+                f"Decode out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {len(seq_lens)} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
         return out_cache_loc
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
@@ -703,7 +770,7 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Reassign
-        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
@@ -711,14 +778,14 @@ class ScheduleBatch:
         )
 
         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.zeros(0, dtype=torch.
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
            )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
 
         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
             )
         else:
@@ -729,25 +796,38 @@ class ScheduleBatch:
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
+        # Allocate req slots
         bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs)
+
+        # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
-        seq_lens = []
-
+        seq_lens = [len(r.fill_ids) for r in reqs]
+        prefix_lens = [len(r.prefix_indices) for r in reqs]
+        extend_lens = [r.extend_input_len for r in reqs]
 
-
-
-
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        prefix_lens_tensor = torch.tensor(
+            prefix_lens, dtype=torch.int64, device=self.device
+        )
+        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
+        # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
 
-
-        for i, req in enumerate(reqs):
+        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
@@ -763,7 +843,7 @@ class ScheduleBatch:
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
-
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 req.extend_logprob_start_len = min(
@@ -819,60 +899,62 @@ class ScheduleBatch:
         else:
             extend_input_logprob_token_ids = None
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = get_last_loc(
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+            )
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+            )
+
         # Set fields
-        self.input_ids =
-
-
-        self.
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = input_ids_tensor
+        self.req_pool_indices = req_pool_indices_tensor
+        self.seq_lens = seq_lens_tensor
+        self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
             if input_embeds
             else None
         )
-
-        self.out_cache_loc = out_cache_loc
-
         self.seq_lens_sum = sum(seq_lens)
+
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
             self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
-
-        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens = [r.extend_input_len for r in reqs]
+
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
-        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
         if global_server_args_dict["attention_backend"] != "torch_native":
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
             write_req_to_token_pool_triton[(bs,)](
                 self.req_to_token_pool.req_to_token,
-
-
-
-
-
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
                 self.req_to_token_pool.req_to_token.shape[1],
             )
         else:
             pt = 0
             for i in range(bs):
                 self.req_to_token_pool.write(
-                    (
-
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
                 )
-                pt +=
-        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+                pt += extend_lens[i]
 
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -918,7 +1000,7 @@ class ScheduleBatch:
         if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs
+        self.tree_cache.evict(bs)
 
         if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
@@ -943,10 +1025,6 @@ class ScheduleBatch:
             reverse=True,
         )
 
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
-
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
@@ -960,6 +1038,9 @@ class ScheduleBatch:
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
         )
 
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
@@ -984,7 +1065,6 @@ class ScheduleBatch:
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
-                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
@@ -1003,9 +1083,7 @@ class ScheduleBatch:
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
-                self.tree_cache.evict(
-                    residual_size, self.token_to_kv_pool_allocator.free
-                )
+                self.tree_cache.evict(residual_size)
 
                req.reset_for_retract()
 
@@ -1028,9 +1106,9 @@ class ScheduleBatch:
 
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.
+        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
-        self.out_cache_loc = torch.empty(0, dtype=torch.
+        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
@@ -1041,6 +1119,8 @@ class ScheduleBatch:
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
+        bs = len(self.reqs)
+
         if self.spec_algorithm.is_eagle():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
@@ -1069,33 +1149,39 @@ class ScheduleBatch:
                 self.output_ids.to(torch.int64)
             )
 
+        # Update fields
         self.input_ids = self.output_ids
         self.output_ids = None
 
-        # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.alloc_token_slots(bs)
-
         if self.model_config.is_encoder_decoder:
             locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
         else:
-            locs = self.seq_lens
+            locs = self.seq_lens.clone()
 
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            self.out_cache_loc = self.alloc_token_slots(bs)
+        else:
+            last_loc = self.req_to_token_pool.req_to_token[
+                self.req_pool_indices, self.seq_lens - 2
+            ]
+            self.out_cache_loc = self.alloc_paged_token_slots_decode(
+                self.seq_lens, last_loc
+            )
+
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
+        )
+
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Req] = None,
@@ -1349,8 +1435,8 @@ def write_req_to_token_pool_triton(
     pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
 
-    #
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_lens + i)
 
@@ -1367,3 +1453,12 @@ def write_req_to_token_pool_triton(
         value,
         mask=mask,
     )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
+    return torch.where(
+        prefix_lens_tensor > 0,
+        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
+        torch.full_like(prefix_lens_tensor, -1),
+    )