sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
33
33
 
34
34
  import copy
35
35
  import dataclasses
36
- import hashlib
37
36
  import logging
38
37
  import threading
39
38
  from enum import Enum, auto
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
53
52
  ScheduleBatchDisaggregationDecodeMixin,
54
53
  )
55
54
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
56
- from sglang.srt.layers.multimodal import gpu_tensor_hash
57
55
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
58
56
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
59
- from sglang.srt.mem_cache.chunk_cache import ChunkCache
57
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
60
58
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
61
59
  from sglang.srt.metrics.collector import TimeStats
62
60
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
87
85
  "deepep_mode",
88
86
  "enable_ep_moe",
89
87
  "enable_flashinfer_moe",
88
+ "enable_flashinfer_allreduce_fusion",
90
89
  "moe_dense_tp_size",
91
90
  "ep_dispatch_algorithm",
92
91
  "deepep_config",
@@ -96,12 +95,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
96
95
  "max_micro_batch_size",
97
96
  "disable_shared_experts_fusion",
98
97
  "sampling_backend",
99
- "speculative_accept_threshold_acc",
100
98
  "speculative_accept_threshold_single",
99
+ "speculative_accept_threshold_acc",
101
100
  "torchao_config",
102
101
  "triton_attention_reduce_in_fp32",
103
102
  "num_reserved_decode_tokens",
104
103
  "weight_loader_disable_mmap",
104
+ "enable_triton_kernel_moe",
105
105
  ]
106
106
 
107
107
  # Put some global args for easy access
@@ -176,50 +176,63 @@ class Modality(Enum):
176
176
  VIDEO = auto()
177
177
  AUDIO = auto()
178
178
 
179
+ @staticmethod
180
+ def from_str(modality_str: str):
181
+ try:
182
+ return Modality[modality_str.upper()]
183
+ except KeyError:
184
+ raise ValueError(
185
+ f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
186
+ )
187
+
179
188
 
180
189
  @dataclasses.dataclass
181
190
  class MultimodalDataItem:
182
191
  """
183
- A single multimodal data, from a single image/video/audio or others
192
+ One MultimodalDataItem contains all inputs for one modality.
193
+ For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
194
+ One for images and one for audio.
195
+
196
+ We put the common fields first and the model-specific fields last.
184
197
  """
185
198
 
186
199
  modality: Modality
187
-
188
200
  hash: int = None
189
201
  pad_value: int = None
190
-
191
- aspect_ratio_id: Optional[List[torch.Tensor]] = None
192
- aspect_ratio_mask: Optional[List[torch.Tensor]] = None
193
-
194
202
  image_sizes: Tuple[int, int] = None
195
203
  image_offsets: Optional[list] = None
196
204
 
197
205
  # the real data, pixel_values or audio_features
198
206
  # data: Union[List[torch.Tensor], List[np.ndarray]]
199
- pixel_values: Union[torch.Tensor, np.ndarray] = None
207
+ pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
208
+ audio_features: Union[torch.Tensor, np.ndarray] = None
209
+ audio_feature_lens: Optional[List[torch.Tensor]] = None
210
+ audio_offsets: Optional[List[Tuple[int, int]]] = None
211
+ precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
212
+
213
+ # For qwen-vl
200
214
  image_grid_thw: Union[torch.Tensor, np.ndarray] = None
201
- video_grid_thws: Union[torch.Tensor, np.ndarray] = None
215
+ second_per_grid_ts: Optional[List[torch.Tensor]] = None
202
216
 
217
+ # For deepseek-vl
203
218
  image_emb_mask: Optional[torch.Tensor] = None
204
219
  image_spatial_crop: Optional[torch.Tensor] = None
205
- second_per_grid_ts: Optional[List[torch.Tensor]] = None
206
220
 
221
+ # For minicpmv
207
222
  # [num_images, (n, w, h)]
208
223
  tgt_size: Tuple[int, int] = None
209
224
 
210
- # kimi-vl related
211
- image_grid_hws: Optional[List[torch.Tensor]] = None
225
+ # For mllama
226
+ aspect_ratio_id: Optional[List[torch.Tensor]] = None
227
+ aspect_ratio_mask: Optional[List[torch.Tensor]] = None
212
228
 
213
- audio_features: Union[torch.Tensor, np.ndarray] = None
214
- audio_feature_lens: Optional[List[torch.Tensor]] = None
215
- audio_offsets: Optional[List[Tuple[int, int]]] = None
229
+ # For kimi-vl
230
+ image_grid_hws: Optional[List[torch.Tensor]] = None
216
231
 
217
- # gemma3n related
232
+ # For gemma3n
218
233
  input_features: Optional[torch.Tensor] = None
219
234
  input_features_mask: Optional[torch.Tensor] = None
220
235
 
221
- precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
222
-
223
236
  @staticmethod
224
237
  def is_empty_list(l):
225
238
  if l is None:
@@ -230,63 +243,18 @@ class MultimodalDataItem:
230
243
  """
231
244
  Set the pad value after first hashing the data
232
245
  """
233
-
234
- def data_hash(data) -> int:
235
- hash_bytes = hashlib.sha256(data).digest()[:8]
236
- return int.from_bytes(hash_bytes, byteorder="big", signed=False)
237
-
238
- def tensor_hash(tensor_list) -> int:
239
- """
240
- hash a tensor or a tensor list
241
- """
242
- tensor = tensor_list
243
- if isinstance(tensor_list, list):
244
- tensor_list = flatten_nested_list(tensor_list)
245
- tensor_list = [
246
- x.flatten() if isinstance(x, torch.Tensor) else x
247
- for x in tensor_list
248
- ]
249
- tensor = torch.concat(tensor_list)
250
- if tensor.is_cuda:
251
- return gpu_tensor_hash(tensor)
252
- tensor = tensor.detach().contiguous()
253
-
254
- if tensor.dtype == torch.bfloat16:
255
- # memoryview() doesn't support PyTorch's BFloat16 dtype
256
- tensor = tensor.float()
257
-
258
- assert isinstance(tensor, torch.Tensor)
259
- if tensor.is_cuda:
260
- # TODO: improve this
261
- tensor_cpu = tensor.cpu()
246
+ from sglang.srt.managers.mm_utils import hash_feature
247
+
248
+ if self.hash is None:
249
+ if self.precomputed_features is not None:
250
+ self.hash = hash_feature(self.precomputed_features)
251
+ elif self.is_audio():
252
+ if self.audio_features is not None:
253
+ self.hash = hash_feature(self.audio_features)
254
+ elif self.input_features is not None:
255
+ self.hash = hash_feature(self.input_features)
262
256
  else:
263
- tensor_cpu = tensor
264
-
265
- mv = memoryview(tensor_cpu.numpy())
266
- return data_hash(mv.tobytes())
267
-
268
- def hash_feature(f):
269
- if isinstance(f, list):
270
- if isinstance(f[0], torch.Tensor):
271
- return tensor_hash(f)
272
- return data_hash(tuple(flatten_nested_list(f)))
273
- elif isinstance(f, np.ndarray):
274
- arr = np.ascontiguousarray(f)
275
- arr_bytes = arr.tobytes()
276
- return data_hash(arr_bytes)
277
- elif isinstance(f, torch.Tensor):
278
- return tensor_hash([f])
279
- return data_hash(f)
280
-
281
- if self.precomputed_features is not None:
282
- self.hash = hash_feature(self.precomputed_features)
283
- elif self.is_audio():
284
- if self.audio_features is not None:
285
- self.hash = hash_feature(self.audio_features)
286
- elif self.input_features is not None:
287
- self.hash = hash_feature(self.input_features)
288
- else:
289
- self.hash = hash_feature(self.pixel_values)
257
+ self.hash = hash_feature(self.pixel_values)
290
258
 
291
259
  assert self.hash is not None
292
260
  self.pad_value = self.hash % (1 << 30)
@@ -329,6 +297,13 @@ class MultimodalDataItem:
329
297
  ret.validate()
330
298
  return ret
331
299
 
300
+ def merge(self, other):
301
+ self.pixel_values += other.pixel_values
302
+ self.image_sizes += other.image_sizes
303
+ self.image_offsets += other.image_offsets
304
+ self.hash = hash((self.hash, other.hash))
305
+ self.set_pad_value()
306
+
332
307
 
333
308
  @dataclasses.dataclass
334
309
  class MultimodalInputs:
@@ -339,10 +314,6 @@ class MultimodalInputs:
339
314
  image_pad_len: Optional[list] = None
340
315
  num_image_tokens: Optional[int] = None
341
316
 
342
- # QWen2-VL related
343
- mrope_positions: Optional[torch.Tensor] = None
344
- mrope_position_delta: Optional[torch.Tensor] = None
345
-
346
317
  # image
347
318
  im_token_id: Optional[int] = None
348
319
  im_start_id: Optional[int] = None
@@ -358,6 +329,10 @@ class MultimodalInputs:
358
329
  audio_start_id: Optional[int] = None
359
330
  audio_end_id: Optional[int] = None
360
331
 
332
+ # QWen2-VL related
333
+ mrope_positions: Optional[torch.Tensor] = None
334
+ mrope_position_delta: Optional[torch.Tensor] = None
335
+
361
336
  @staticmethod
362
337
  def from_dict(obj: dict):
363
338
  ret = MultimodalInputs(
@@ -485,6 +460,9 @@ class Req:
485
460
  # for corss-endoder model
486
461
  self.token_type_ids = token_type_ids
487
462
 
463
+ # The length of KV that have been removed in local attention chunked prefill
464
+ self.evicted_seqlen_local = 0
465
+
488
466
  # Sampling info
489
467
  if isinstance(sampling_params.custom_params, dict):
490
468
  sampling_params = copy.copy(sampling_params)
@@ -863,8 +841,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
863
841
  # For DP attention
864
842
  global_num_tokens: Optional[List[int]] = None
865
843
  global_num_tokens_for_logprob: Optional[List[int]] = None
866
- can_run_dp_cuda_graph: bool = False
867
844
  is_extend_in_batch: bool = False
845
+ can_run_dp_cuda_graph: bool = False
868
846
  tbo_split_seq_index: Optional[int] = None
869
847
  global_forward_mode: Optional[ForwardMode] = None
870
848
 
@@ -1191,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1191
1169
  self.req_to_token_pool.write(
1192
1170
  (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
1193
1171
  )
1172
+ if isinstance(self.tree_cache, SWAChunkCache):
1173
+ self.tree_cache.evict(
1174
+ req, pre_len, self.model_config.attention_chunk_size
1175
+ )
1194
1176
 
1195
1177
  # If input_embeds are available, store them
1196
1178
  if req.input_embeds is not None:
@@ -1383,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1383
1365
  * buf_multiplier
1384
1366
  * self.token_to_kv_pool_allocator.page_size
1385
1367
  )
1386
-
1387
1368
  if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
1388
1369
  return True
1389
1370
 
@@ -1564,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1564
1545
  self.seq_lens.add_(1)
1565
1546
  self.seq_lens_sum += bs
1566
1547
 
1548
+ # free memory
1549
+ if isinstance(self.tree_cache, SWAChunkCache):
1550
+ for req in self.reqs:
1551
+ self.tree_cache.evict(
1552
+ req, req.seqlen - 1, self.model_config.attention_chunk_size
1553
+ )
1554
+
1567
1555
  # Allocate memory
1568
1556
  if self.token_to_kv_pool_allocator.page_size == 1:
1569
1557
  self.out_cache_loc = self.alloc_token_slots(bs)
@@ -1694,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1694
1682
  )
1695
1683
  or global_server_args_dict["attention_backend"] == "flashmla"
1696
1684
  or global_server_args_dict["attention_backend"] == "cutlass_mla"
1685
+ or global_server_args_dict["attention_backend"] == "ascend"
1697
1686
  or global_server_args_dict["enable_two_batch_overlap"]
1698
1687
  ):
1699
1688
  seq_lens_cpu = (
@@ -1726,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1726
1715
  token_ids_logprobs=self.token_ids_logprobs,
1727
1716
  global_num_tokens=self.global_num_tokens,
1728
1717
  global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
1718
+ is_extend_in_batch=self.is_extend_in_batch,
1729
1719
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
1730
1720
  tbo_split_seq_index=self.tbo_split_seq_index,
1731
1721
  global_forward_mode=self.global_forward_mode,
@@ -1798,7 +1788,6 @@ class ModelWorkerBatch:
1798
1788
  seq_lens: torch.Tensor
1799
1789
  # The indices of output tokens in the token_to_kv_pool_allocator
1800
1790
  out_cache_loc: torch.Tensor
1801
-
1802
1791
  # The sequence length tensor on CPU
1803
1792
  seq_lens_cpu: Optional[torch.Tensor]
1804
1793
  seq_lens_sum: int
@@ -1811,6 +1800,7 @@ class ModelWorkerBatch:
1811
1800
  # For DP attention
1812
1801
  global_num_tokens: Optional[List[int]]
1813
1802
  global_num_tokens_for_logprob: Optional[List[int]]
1803
+ is_extend_in_batch: bool
1814
1804
  can_run_dp_cuda_graph: bool
1815
1805
  tbo_split_seq_index: Optional[int]
1816
1806
  global_forward_mode: Optional[ForwardMode]
@@ -1897,7 +1887,10 @@ def get_last_loc(
1897
1887
  req_pool_indices_tensor: torch.Tensor,
1898
1888
  prefix_lens_tensor: torch.Tensor,
1899
1889
  ) -> torch.Tensor:
1900
- if global_server_args_dict["attention_backend"] != "torch_native":
1890
+ if (
1891
+ global_server_args_dict["attention_backend"] != "ascend"
1892
+ and global_server_args_dict["attention_backend"] != "torch_native"
1893
+ ):
1901
1894
  impl = get_last_loc_triton
1902
1895
  else:
1903
1896
  impl = get_last_loc_torch