sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
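Two package-level moves dominate the rename entries above: the EPLB machinery leaves sglang.srt.managers for a new sglang.srt.eplb package, and the multimodal processors move from sglang.srt.managers.multimodal_processors to sglang.srt.multimodal.processors. A minimal import-migration sketch, with paths taken from the rename list (the specific modules shown are just examples; substitute whichever moved module you import):

```python
# Import paths before the move (sglang 0.4.8):
#   from sglang.srt.managers import expert_distribution
#   from sglang.srt.managers.multimodal_processors import clip

# Equivalent paths after the move (sglang 0.4.9), per the rename entries above:
from sglang.srt.eplb import expert_distribution
from sglang.srt.multimodal.processors import clip
```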
sglang/srt/managers/schedule_batch.py
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 
 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
@@ -96,8 +95,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -176,45 +175,62 @@ class Modality(Enum):
     VIDEO = auto()
     AUDIO = auto()
 
+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )
+
 
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.
+
+    We put the common fields first and the model-specific fields last.
     """
 
     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
 
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
 
-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
 
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+    # For gemma3n
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
 
     @staticmethod
     def is_empty_list(l):
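The new Modality.from_str maps case-insensitive strings onto enum members and raises on unknown names. A quick usage sketch (IMAGE is an existing member of the enum; the invalid name below is made up):

```python
from sglang.srt.managers.schedule_batch import Modality

assert Modality.from_str("image") is Modality.IMAGE
assert Modality.from_str("AUDIO") is Modality.AUDIO

try:
    Modality.from_str("hologram")  # hypothetical invalid name
except ValueError as err:
    print(err)  # message lists the valid modality names
```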
@@ -226,60 +242,18 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
+        from sglang.srt.managers.mm_utils import hash_feature
+
+        if self.hash is None:
+            if self.precomputed_features is not None:
+                self.hash = hash_feature(self.precomputed_features)
+            elif self.is_audio():
+                if self.audio_features is not None:
+                    self.hash = hash_feature(self.audio_features)
+                elif self.input_features is not None:
+                    self.hash = hash_feature(self.input_features)
             else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
-
-        if self.precomputed_features is not None:
-            self.hash = hash_feature(self.precomputed_features)
-        elif self.is_audio():
-            self.hash = hash_feature(self.audio_features)
-        else:
-            self.hash = hash_feature(self.pixel_values)
+                self.hash = hash_feature(self.pixel_values)
 
         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
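The inline data_hash/tensor_hash/hash_feature helpers deleted here now live in sglang.srt.managers.mm_utils (see entry 68 in the file list). A simplified sketch of the hashing scheme, reconstructed from the deleted block; it omits the gpu_tensor_hash fast path and the nested-list flattening that the real helper keeps:

```python
import hashlib

import numpy as np
import torch


def data_hash(data) -> int:
    # First 8 bytes of a SHA-256 digest as an unsigned 64-bit integer.
    return int.from_bytes(hashlib.sha256(data).digest()[:8], byteorder="big", signed=False)


def hash_feature(f) -> int:
    # Hash a tensor, an ndarray, or a list of tensors by their raw bytes.
    if isinstance(f, torch.Tensor):
        t = f.detach().contiguous()
        if t.dtype == torch.bfloat16:
            # memoryview() cannot wrap bfloat16 storage, so upcast first.
            t = t.float()
        return data_hash(memoryview(t.cpu().numpy()).tobytes())
    if isinstance(f, np.ndarray):
        return data_hash(np.ascontiguousarray(f).tobytes())
    if isinstance(f, list) and f and isinstance(f[0], torch.Tensor):
        return hash_feature(torch.cat([x.flatten() for x in f]))
    raise TypeError(f"unsupported feature type: {type(f)}")  # real helper handles more cases
```

The resulting 64-bit hash seeds pad_value (hash % (1 << 30)), which is how identical multimodal inputs can be recognized and shared across requests.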
@@ -288,6 +262,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )
 
     def is_image(self):
@@ -321,6 +296,13 @@ class MultimodalDataItem:
         ret.validate()
         return ret
 
+    def merge(self, other):
+        self.pixel_values += other.pixel_values
+        self.image_sizes += other.image_sizes
+        self.image_offsets += other.image_offsets
+        self.hash = hash((self.hash, other.hash))
+        self.set_pad_value()
+
 
 @dataclasses.dataclass
 class MultimodalInputs:
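merge() concatenates the per-item fields with += and mixes the two hashes, so it assumes pixel_values/image_sizes/image_offsets hold lists. An illustrative sketch (the tensors, sizes, and offsets are made up):

```python
import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

left = MultimodalDataItem(
    modality=Modality.IMAGE,
    pixel_values=[torch.randn(3, 224, 224)],
    image_sizes=[(224, 224)],
    image_offsets=[5],
)
right = MultimodalDataItem(
    modality=Modality.IMAGE,
    pixel_values=[torch.randn(3, 448, 448)],
    image_sizes=[(448, 448)],
    image_offsets=[81],
)
left.set_pad_value()
right.set_pad_value()

left.merge(right)  # left now carries both images
assert len(left.pixel_values) == 2
# merge() mixes the two hashes and refreshes pad_value deterministically.
assert left.pad_value == left.hash % (1 << 30)
```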
@@ -331,10 +313,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
@@ -350,6 +328,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None
 
+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
@@ -477,6 +459,9 @@ class Req:
         # for corss-endoder model
         self.token_type_ids = token_type_ids
 
+        # The length of KV that have been removed in local attention chunked prefill
+        self.evicted_seqlen_local = 0
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -855,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
@@ -1183,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
+            if isinstance(self.tree_cache, SWAChunkCache):
+                self.tree_cache.evict(
+                    req, pre_len, self.model_config.attention_chunk_size
+                )
 
         # If input_embeds are available, store them
         if req.input_embeds is not None:
@@ -1375,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-
         if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True
@@ -1556,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # free memory
+        if isinstance(self.tree_cache, SWAChunkCache):
+            for req in self.reqs:
+                self.tree_cache.evict(
+                    req, req.seqlen - 1, self.model_config.attention_chunk_size
+                )
+
         # Allocate memory
         if self.token_to_kv_pool_allocator.page_size == 1:
             self.out_cache_loc = self.alloc_token_slots(bs)
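Both evict() call sites (chunked prefill above, and this decode path) pass the request, the sequence length that is currently safe to trim behind, and model_config.attention_chunk_size. A rough sketch of what chunk-aligned sliding-window eviction has to do, using the evicted_seqlen_local counter added to Req earlier; req_to_token and token_to_kv_pool_allocator stand in for the scheduler's pools, and this is not the actual SWAChunkCache code:

```python
def evict_local_window(req, seqlen: int, attention_chunk_size: int) -> None:
    # Chunked local attention never reads KV from before the chunk that
    # contains the current position, so whole chunks below that boundary
    # can be freed. Freeing chunk-aligned spans keeps the rest intact.
    boundary = attention_chunk_size * (seqlen // attention_chunk_size)
    if boundary > req.evicted_seqlen_local:
        # Slots in the request's token map covering the evictable span.
        slots = req_to_token[req.req_pool_idx, req.evicted_seqlen_local:boundary]
        token_to_kv_pool_allocator.free(slots)
        req.evicted_seqlen_local = boundary
```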
@@ -1686,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["attention_backend"] == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
1718
1715
  token_ids_logprobs=self.token_ids_logprobs,
1719
1716
  global_num_tokens=self.global_num_tokens,
1720
1717
  global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
1718
+ is_extend_in_batch=self.is_extend_in_batch,
1721
1719
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
1722
1720
  tbo_split_seq_index=self.tbo_split_seq_index,
1723
1721
  global_forward_mode=self.global_forward_mode,
@@ -1790,7 +1788,6 @@ class ModelWorkerBatch:
1790
1788
  seq_lens: torch.Tensor
1791
1789
  # The indices of output tokens in the token_to_kv_pool_allocator
1792
1790
  out_cache_loc: torch.Tensor
1793
-
1794
1791
  # The sequence length tensor on CPU
1795
1792
  seq_lens_cpu: Optional[torch.Tensor]
1796
1793
  seq_lens_sum: int
@@ -1803,6 +1800,7 @@ class ModelWorkerBatch:
    # For DP attention
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
@@ -1889,7 +1887,10 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if global_server_args_dict["attention_backend"] != "torch_native":
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
         impl = get_last_loc_triton
     else:
         impl = get_last_loc_torch
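With this change the Ascend backend joins torch_native on the pure-torch path instead of the Triton kernel. For reference, a torch-only get_last_loc can be a single vectorized expression; a sketch consistent with the signature above, not copied from the sglang source:

```python
import torch


def get_last_loc_torch(
    req_to_token: torch.Tensor,             # [num_reqs, max_context_len]
    req_pool_indices_tensor: torch.Tensor,  # one row index per request
    prefix_lens_tensor: torch.Tensor,       # cached prefix length per request
) -> torch.Tensor:
    # Last cached KV location per request, or -1 if nothing is cached.
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )
```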