sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -49,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -158,15 +159,19 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
-    # MiniCPMV related
+    # The id of the single-image placeholder token
+    im_token_id: Optional[torch.Tensor] = None
     # All the images in the batch should share the same special image
     # bound token ids.
-    im_start_id: Optional[torch.Tensor] = None
-    im_end_id: Optional[torch.Tensor] = None
-    slice_start_id: Optional[torch.Tensor] = None
-    slice_end_id: Optional[torch.Tensor] = None
+    im_start_id: Optional[int] = None
+    im_end_id: Optional[int] = None
+    slice_start_id: Optional[int] = None
+    slice_end_id: Optional[int] = None
     tgt_sizes: Optional[list] = None
 
+    # denotes the number of valid image tokens in each image
+    images_emb_mask: Optional[torch.BoolTensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -186,11 +191,13 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
             "tgt_sizes",
+            "images_emb_mask",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -267,7 +274,6 @@ class Req:
                 "__req__": self
             }
         self.sampling_params = sampling_params
-
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
 
@@ -309,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -324,6 +331,8 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
         self.token_ids_logprob = token_ids_logprob
+        self.temp_scaled_logprobs = False
+        self.top_p_normalized_logprobs = False
 
         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
@@ -352,7 +361,7 @@ class Req:
         ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
             self.output_token_ids_logprobs_idx
         ) = None
-        self.hidden_states = []
+        self.hidden_states: List[List[float]] = []
 
         # Embedding (return values)
         self.embedding = None
@@ -383,13 +392,24 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -423,28 +443,6 @@ class Req:
         all_ids = self.origin_input_ids_unpadded + self.output_ids
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-    def get_next_inc_detokenization(self):
-        if self.tokenizer is None:
-            return False, ""
-        read_ids, read_offset = self.init_incremental_detokenize()
-        surr_ids = read_ids[:read_offset]
-
-        surr_text = self.tokenizer.decode(
-            surr_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-        new_text = self.tokenizer.decode(
-            read_ids,
-            skip_special_tokens=self.sampling_params.skip_special_tokens,
-            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-        )
-
-        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            return True, new_text[len(surr_text) :]
-
-        return False, ""
-
     def check_finished(self):
         if self.finished():
             return
@@ -528,19 +526,23 @@ class ScheduleBatch:
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
     enable_overlap: bool = False
+    # Tell whether the current running batch is full so that we can skip
+    # the check of whether to prefill new requests.
+    # This is an optimization to reduce the overhead of the prefill check.
+    batch_is_full: bool = False
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
-    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None  # shape: [b], int32
-    output_ids: torch.Tensor = None  # shape: [b], int32
+    out_cache_loc: torch.Tensor = None  # shape: [b], int64
+    output_ids: torch.Tensor = None  # shape: [b], int64
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -555,6 +557,10 @@ class ScheduleBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprob post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
@@ -564,7 +570,7 @@ class ScheduleBatch:
     # It comes empty list if logprob is not required.
     extend_input_logprob_token_ids: Optional[torch.Tensor] = None
 
-    # For encoder-decoder
+    # For encoder-decoder architectures
     encoder_cached: Optional[List[bool]] = None
     encoder_lens: Optional[torch.Tensor] = None
     encoder_lens_cpu: Optional[List[int]] = None
@@ -601,6 +607,8 @@ class ScheduleBatch:
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
     ):
+        return_logprob = any(req.return_logprob for req in reqs)
+
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
@@ -608,7 +616,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
-            return_logprob=any(req.return_logprob for req in reqs),
+            return_logprob=return_logprob,
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
@@ -635,24 +643,83 @@ class ScheduleBatch:
         return req_pool_indices
 
     def alloc_token_slots(self, num_tokens: int):
+        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens)
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
+        if out_cache_loc is None:
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            error_msg = (
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+            )
+            logger.error(error_msg)
+            if self.tree_cache is not None:
+                self.tree_cache.pretty_print()
+            raise RuntimeError(error_msg)
+
+        return out_cache_loc
+
+    def alloc_paged_token_slots_extend(
+        self,
+        prefix_lens: torch.Tensor,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+        extend_num_tokens: int,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < extend_num_tokens
+            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
+            if self.tree_cache is not None:
+                self.tree_cache.evict(
+                    extend_num_tokens
+                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
+                )
 
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
+            prefix_lens, seq_lens, last_loc, extend_num_tokens
+        )
         if out_cache_loc is None:
+            error_msg = (
+                f"Prefill out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {extend_num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        return out_cache_loc
+
+    def alloc_paged_token_slots_decode(
+        self,
+        seq_lens: torch.Tensor,
+        last_loc: torch.Tensor,
+    ):
+        if (
+            self.token_to_kv_pool_allocator.available_size()
+            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        ):
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
-                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
-
-            if out_cache_loc is None:
-                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-                logger.error(
-                    f"{phase_str} out of memory. Try to lower your batch size.\n"
-                    f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                self.tree_cache.evict(
+                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
-                if self.tree_cache is not None:
-                    self.tree_cache.pretty_print()
-                exit(1)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+        if out_cache_loc is None:
+            error_msg = (
+                f"Decode out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {len(seq_lens)} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
         return out_cache_loc
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
@@ -703,7 +770,7 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Reassign
-        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
@@ -711,14 +778,14 @@ class ScheduleBatch:
         )
 
         if not decoder_out_cache_loc:
-            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
+            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
             )
         else:
             self.out_cache_loc = torch.cat(decoder_out_cache_loc)
 
         if not encoder_out_cache_loc:
-            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
+            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                 self.device, non_blocking=True
             )
         else:
@@ -729,25 +796,38 @@ class ScheduleBatch:
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
+        # Allocate req slots
         bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs)
+
+        # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
-        seq_lens = []
-        pre_lens = []
+        seq_lens = [len(r.fill_ids) for r in reqs]
+        prefix_lens = [len(r.prefix_indices) for r in reqs]
+        extend_lens = [r.extend_input_len for r in reqs]
 
-        # Allocate memory
-        req_pool_indices = self.alloc_req_slots(bs)
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+        prefix_lens_tensor = torch.tensor(
+            prefix_lens, dtype=torch.int64, device=self.device
+        )
+        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
+        # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
 
-        pt = 0
-        for i, req in enumerate(reqs):
+        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
@@ -763,7 +843,7 @@ class ScheduleBatch:
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
-            pre_lens.append(pre_len)
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 req.extend_logprob_start_len = min(
@@ -819,60 +899,62 @@ class ScheduleBatch:
         else:
             extend_input_logprob_token_ids = None
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = get_last_loc(
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+            )
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+            )
+
         # Set fields
-        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = input_ids_tensor
+        self.req_pool_indices = req_pool_indices_tensor
+        self.seq_lens = seq_lens_tensor
+        self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
             if input_embeds
             else None
         )
-
-        self.out_cache_loc = out_cache_loc
-
         self.seq_lens_sum = sum(seq_lens)
+
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
             self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens = [r.extend_input_len for r in reqs]
+
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
-        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
         if global_server_args_dict["attention_backend"] != "torch_native":
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
             write_req_to_token_pool_triton[(bs,)](
                 self.req_to_token_pool.req_to_token,
-                self.req_pool_indices,
-                pre_lens,
-                self.seq_lens,
-                extend_lens,
-                self.out_cache_loc,
+                req_pool_indices_tensor,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
                 self.req_to_token_pool.req_to_token.shape[1],
             )
         else:
             pt = 0
             for i in range(bs):
                 self.req_to_token_pool.write(
-                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
-                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
                 )
-                pt += self.extend_lens[i]
-        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+                pt += extend_lens[i]
 
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -918,7 +1000,7 @@ class ScheduleBatch:
         if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
+        self.tree_cache.evict(bs)
 
         if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
@@ -943,10 +1025,6 @@ class ScheduleBatch:
             reverse=True,
         )
 
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
-
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
@@ -960,6 +1038,9 @@ class ScheduleBatch:
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )
 
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
@@ -984,7 +1065,6 @@ class ScheduleBatch:
                 ]
                 self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
-                del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
@@ -1003,9 +1083,7 @@ class ScheduleBatch:
                     - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(
-                    residual_size, self.token_to_kv_pool_allocator.free
-                )
+                self.tree_cache.evict(residual_size)
 
             req.reset_for_retract()
 
@@ -1028,9 +1106,9 @@ class ScheduleBatch:
 
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
@@ -1041,6 +1119,8 @@ class ScheduleBatch:
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
+        bs = len(self.reqs)
+
         if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
@@ -1069,33 +1149,39 @@ class ScheduleBatch:
                     self.output_ids.to(torch.int64)
                 )
 
+        # Update fields
         self.input_ids = self.output_ids
         self.output_ids = None
 
-        # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.alloc_token_slots(bs)
-
         if self.model_config.is_encoder_decoder:
             locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
         else:
-            locs = self.seq_lens
+            locs = self.seq_lens.clone()
 
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens = self.seq_lens + 1
         else:
             # A faster in-place version
-            self.req_to_token_pool.write(
-                (self.req_pool_indices, locs), self.out_cache_loc
-            )
             self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            self.out_cache_loc = self.alloc_token_slots(bs)
+        else:
+            last_loc = self.req_to_token_pool.req_to_token[
+                self.req_pool_indices, self.seq_lens - 2
+            ]
+            self.out_cache_loc = self.alloc_paged_token_slots_decode(
+                self.seq_lens, last_loc
+            )
+
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
+        )
+
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Req] = None,
@@ -1349,8 +1435,8 @@ def write_req_to_token_pool_triton(
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
 
-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_lens + i)
 
@@ -1367,3 +1453,12 @@
         value,
         mask=mask,
     )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
+    return torch.where(
+        prefix_lens_tensor > 0,
+        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
+        torch.full_like(prefix_lens_tensor, -1),
+    )
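
For reference, a minimal sketch of what the new get_last_loc helper above computes for the paged allocators, run on made-up toy tensors (the table values below are illustrative only and not taken from sglang):

import torch

# Toy request-to-token table: row = request slot, column = token position, value = KV-cache slot.
req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 22, 23]])
req_pool_indices = torch.tensor([0, 1])
prefix_lens = torch.tensor([3, 0])  # request 0 has a 3-token cached prefix, request 1 has none

# Same torch.where pattern as get_last_loc: the KV slot of each request's last
# prefix token, or -1 when the request has no cached prefix.
last_loc = torch.where(
    prefix_lens > 0,
    req_to_token[req_pool_indices, prefix_lens - 1],
    torch.full_like(prefix_lens, -1),
)
print(last_loc)  # tensor([12, -1])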