sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
33
33
 
34
34
  import copy
35
35
  import dataclasses
36
- import hashlib
37
36
  import logging
38
37
  import threading
39
38
  from enum import Enum, auto
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
53
52
  ScheduleBatchDisaggregationDecodeMixin,
54
53
  )
55
54
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
56
- from sglang.srt.layers.multimodal import gpu_tensor_hash
57
55
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
58
56
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
59
- from sglang.srt.mem_cache.chunk_cache import ChunkCache
57
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
60
58
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
61
59
  from sglang.srt.metrics.collector import TimeStats
62
60
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
87
85
  "deepep_mode",
88
86
  "enable_ep_moe",
89
87
  "enable_flashinfer_moe",
88
+ "enable_flashinfer_allreduce_fusion",
90
89
  "moe_dense_tp_size",
91
90
  "ep_dispatch_algorithm",
92
91
  "deepep_config",
@@ -96,8 +95,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
96
95
  "max_micro_batch_size",
97
96
  "disable_shared_experts_fusion",
98
97
  "sampling_backend",
99
- "speculative_accept_threshold_acc",
100
98
  "speculative_accept_threshold_single",
99
+ "speculative_accept_threshold_acc",
101
100
  "torchao_config",
102
101
  "triton_attention_reduce_in_fp32",
103
102
  "num_reserved_decode_tokens",
@@ -176,50 +175,63 @@ class Modality(Enum):
176
175
  VIDEO = auto()
177
176
  AUDIO = auto()
178
177
 
178
+ @staticmethod
179
+ def from_str(modality_str: str):
180
+ try:
181
+ return Modality[modality_str.upper()]
182
+ except KeyError:
183
+ raise ValueError(
184
+ f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
185
+ )
186
+
179
187
 
180
188
  @dataclasses.dataclass
181
189
  class MultimodalDataItem:
182
190
  """
183
- A single multimodal data, from a single image/video/audio or others
191
+ One MultimodalDataItem contains all inputs for one modality.
192
+ For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
193
+ One for images and one for audio.
194
+
195
+ We put the common fields first and the model-specific fields last.
184
196
  """
185
197
 
186
198
  modality: Modality
187
-
188
199
  hash: int = None
189
200
  pad_value: int = None
190
-
191
- aspect_ratio_id: Optional[List[torch.Tensor]] = None
192
- aspect_ratio_mask: Optional[List[torch.Tensor]] = None
193
-
194
201
  image_sizes: Tuple[int, int] = None
195
202
  image_offsets: Optional[list] = None
196
203
 
197
204
  # the real data, pixel_values or audio_features
198
205
  # data: Union[List[torch.Tensor], List[np.ndarray]]
199
- pixel_values: Union[torch.Tensor, np.ndarray] = None
206
+ pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
207
+ audio_features: Union[torch.Tensor, np.ndarray] = None
208
+ audio_feature_lens: Optional[List[torch.Tensor]] = None
209
+ audio_offsets: Optional[List[Tuple[int, int]]] = None
210
+ precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
211
+
212
+ # For qwen-vl
200
213
  image_grid_thw: Union[torch.Tensor, np.ndarray] = None
201
- video_grid_thws: Union[torch.Tensor, np.ndarray] = None
214
+ second_per_grid_ts: Optional[List[torch.Tensor]] = None
202
215
 
216
+ # For deepseek-vl
203
217
  image_emb_mask: Optional[torch.Tensor] = None
204
218
  image_spatial_crop: Optional[torch.Tensor] = None
205
- second_per_grid_ts: Optional[List[torch.Tensor]] = None
206
219
 
220
+ # For minicpmv
207
221
  # [num_images, (n, w, h)]
208
222
  tgt_size: Tuple[int, int] = None
209
223
 
210
- # kimi-vl related
211
- image_grid_hws: Optional[List[torch.Tensor]] = None
224
+ # For mllama
225
+ aspect_ratio_id: Optional[List[torch.Tensor]] = None
226
+ aspect_ratio_mask: Optional[List[torch.Tensor]] = None
212
227
 
213
- audio_features: Union[torch.Tensor, np.ndarray] = None
214
- audio_feature_lens: Optional[List[torch.Tensor]] = None
215
- audio_offsets: Optional[List[Tuple[int, int]]] = None
228
+ # For kimi-vl
229
+ image_grid_hws: Optional[List[torch.Tensor]] = None
216
230
 
217
- # gemma3n related
231
+ # For gemma3n
218
232
  input_features: Optional[torch.Tensor] = None
219
233
  input_features_mask: Optional[torch.Tensor] = None
220
234
 
221
- precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
222
-
223
235
  @staticmethod
224
236
  def is_empty_list(l):
225
237
  if l is None:
@@ -230,63 +242,18 @@ class MultimodalDataItem:
230
242
  """
231
243
  Set the pad value after first hashing the data
232
244
  """
233
-
234
- def data_hash(data) -> int:
235
- hash_bytes = hashlib.sha256(data).digest()[:8]
236
- return int.from_bytes(hash_bytes, byteorder="big", signed=False)
237
-
238
- def tensor_hash(tensor_list) -> int:
239
- """
240
- hash a tensor or a tensor list
241
- """
242
- tensor = tensor_list
243
- if isinstance(tensor_list, list):
244
- tensor_list = flatten_nested_list(tensor_list)
245
- tensor_list = [
246
- x.flatten() if isinstance(x, torch.Tensor) else x
247
- for x in tensor_list
248
- ]
249
- tensor = torch.concat(tensor_list)
250
- if tensor.is_cuda:
251
- return gpu_tensor_hash(tensor)
252
- tensor = tensor.detach().contiguous()
253
-
254
- if tensor.dtype == torch.bfloat16:
255
- # memoryview() doesn't support PyTorch's BFloat16 dtype
256
- tensor = tensor.float()
257
-
258
- assert isinstance(tensor, torch.Tensor)
259
- if tensor.is_cuda:
260
- # TODO: improve this
261
- tensor_cpu = tensor.cpu()
245
+ from sglang.srt.managers.mm_utils import hash_feature
246
+
247
+ if self.hash is None:
248
+ if self.precomputed_features is not None:
249
+ self.hash = hash_feature(self.precomputed_features)
250
+ elif self.is_audio():
251
+ if self.audio_features is not None:
252
+ self.hash = hash_feature(self.audio_features)
253
+ elif self.input_features is not None:
254
+ self.hash = hash_feature(self.input_features)
262
255
  else:
263
- tensor_cpu = tensor
264
-
265
- mv = memoryview(tensor_cpu.numpy())
266
- return data_hash(mv.tobytes())
267
-
268
- def hash_feature(f):
269
- if isinstance(f, list):
270
- if isinstance(f[0], torch.Tensor):
271
- return tensor_hash(f)
272
- return data_hash(tuple(flatten_nested_list(f)))
273
- elif isinstance(f, np.ndarray):
274
- arr = np.ascontiguousarray(f)
275
- arr_bytes = arr.tobytes()
276
- return data_hash(arr_bytes)
277
- elif isinstance(f, torch.Tensor):
278
- return tensor_hash([f])
279
- return data_hash(f)
280
-
281
- if self.precomputed_features is not None:
282
- self.hash = hash_feature(self.precomputed_features)
283
- elif self.is_audio():
284
- if self.audio_features is not None:
285
- self.hash = hash_feature(self.audio_features)
286
- elif self.input_features is not None:
287
- self.hash = hash_feature(self.input_features)
288
- else:
289
- self.hash = hash_feature(self.pixel_values)
256
+ self.hash = hash_feature(self.pixel_values)
290
257
 
291
258
  assert self.hash is not None
292
259
  self.pad_value = self.hash % (1 << 30)
@@ -329,6 +296,13 @@ class MultimodalDataItem:
329
296
  ret.validate()
330
297
  return ret
331
298
 
299
+ def merge(self, other):
300
+ self.pixel_values += other.pixel_values
301
+ self.image_sizes += other.image_sizes
302
+ self.image_offsets += other.image_offsets
303
+ self.hash = hash((self.hash, other.hash))
304
+ self.set_pad_value()
305
+
332
306
 
333
307
  @dataclasses.dataclass
334
308
  class MultimodalInputs:
@@ -339,10 +313,6 @@ class MultimodalInputs:
339
313
  image_pad_len: Optional[list] = None
340
314
  num_image_tokens: Optional[int] = None
341
315
 
342
- # QWen2-VL related
343
- mrope_positions: Optional[torch.Tensor] = None
344
- mrope_position_delta: Optional[torch.Tensor] = None
345
-
346
316
  # image
347
317
  im_token_id: Optional[int] = None
348
318
  im_start_id: Optional[int] = None
@@ -358,6 +328,10 @@ class MultimodalInputs:
358
328
  audio_start_id: Optional[int] = None
359
329
  audio_end_id: Optional[int] = None
360
330
 
331
+ # QWen2-VL related
332
+ mrope_positions: Optional[torch.Tensor] = None
333
+ mrope_position_delta: Optional[torch.Tensor] = None
334
+
361
335
  @staticmethod
362
336
  def from_dict(obj: dict):
363
337
  ret = MultimodalInputs(
@@ -485,6 +459,9 @@ class Req:
485
459
  # for corss-endoder model
486
460
  self.token_type_ids = token_type_ids
487
461
 
462
+ # The length of KV that have been removed in local attention chunked prefill
463
+ self.evicted_seqlen_local = 0
464
+
488
465
  # Sampling info
489
466
  if isinstance(sampling_params.custom_params, dict):
490
467
  sampling_params = copy.copy(sampling_params)
@@ -863,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
863
840
  # For DP attention
864
841
  global_num_tokens: Optional[List[int]] = None
865
842
  global_num_tokens_for_logprob: Optional[List[int]] = None
843
+ is_extend_in_batch: bool = False
866
844
  can_run_dp_cuda_graph: bool = False
867
845
  is_extend_in_batch: bool = False
868
846
  tbo_split_seq_index: Optional[int] = None
@@ -1191,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1191
1169
  self.req_to_token_pool.write(
1192
1170
  (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
1193
1171
  )
1172
+ if isinstance(self.tree_cache, SWAChunkCache):
1173
+ self.tree_cache.evict(
1174
+ req, pre_len, self.model_config.attention_chunk_size
1175
+ )
1194
1176
 
1195
1177
  # If input_embeds are available, store them
1196
1178
  if req.input_embeds is not None:
@@ -1383,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1383
1365
  * buf_multiplier
1384
1366
  * self.token_to_kv_pool_allocator.page_size
1385
1367
  )
1386
-
1387
1368
  if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
1388
1369
  return True
1389
1370
 
@@ -1564,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1564
1545
  self.seq_lens.add_(1)
1565
1546
  self.seq_lens_sum += bs
1566
1547
 
1548
+ # free memory
1549
+ if isinstance(self.tree_cache, SWAChunkCache):
1550
+ for req in self.reqs:
1551
+ self.tree_cache.evict(
1552
+ req, req.seqlen - 1, self.model_config.attention_chunk_size
1553
+ )
1554
+
1567
1555
  # Allocate memory
1568
1556
  if self.token_to_kv_pool_allocator.page_size == 1:
1569
1557
  self.out_cache_loc = self.alloc_token_slots(bs)
@@ -1694,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1694
1682
  )
1695
1683
  or global_server_args_dict["attention_backend"] == "flashmla"
1696
1684
  or global_server_args_dict["attention_backend"] == "cutlass_mla"
1685
+ or global_server_args_dict["attention_backend"] == "ascend"
1697
1686
  or global_server_args_dict["enable_two_batch_overlap"]
1698
1687
  ):
1699
1688
  seq_lens_cpu = (
@@ -1726,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1726
1715
  token_ids_logprobs=self.token_ids_logprobs,
1727
1716
  global_num_tokens=self.global_num_tokens,
1728
1717
  global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
1718
+ is_extend_in_batch=self.is_extend_in_batch,
1729
1719
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
1730
1720
  tbo_split_seq_index=self.tbo_split_seq_index,
1731
1721
  global_forward_mode=self.global_forward_mode,
@@ -1798,7 +1788,6 @@ class ModelWorkerBatch:
1798
1788
  seq_lens: torch.Tensor
1799
1789
  # The indices of output tokens in the token_to_kv_pool_allocator
1800
1790
  out_cache_loc: torch.Tensor
1801
-
1802
1791
  # The sequence length tensor on CPU
1803
1792
  seq_lens_cpu: Optional[torch.Tensor]
1804
1793
  seq_lens_sum: int
@@ -1811,6 +1800,7 @@ class ModelWorkerBatch:
1811
1800
  # For DP attention
1812
1801
  global_num_tokens: Optional[List[int]]
1813
1802
  global_num_tokens_for_logprob: Optional[List[int]]
1803
+ is_extend_in_batch: bool
1814
1804
  can_run_dp_cuda_graph: bool
1815
1805
  tbo_split_seq_index: Optional[int]
1816
1806
  global_forward_mode: Optional[ForwardMode]
@@ -1897,7 +1887,10 @@ def get_last_loc(
1897
1887
  req_pool_indices_tensor: torch.Tensor,
1898
1888
  prefix_lens_tensor: torch.Tensor,
1899
1889
  ) -> torch.Tensor:
1900
- if global_server_args_dict["attention_backend"] != "torch_native":
1890
+ if (
1891
+ global_server_args_dict["attention_backend"] != "ascend"
1892
+ and global_server_args_dict["attention_backend"] != "torch_native"
1893
+ ):
1901
1894
  impl = get_last_loc_triton
1902
1895
  else:
1903
1896
  impl = get_last_loc_torch