sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the sglang package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
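The {managers → eplb} and {managers/multimodal_processors → multimodal/processors} entries above are module moves, so downstream code importing the old paths will break on upgrade. A minimal before/after sketch, assuming no compatibility shims are installed; ExpertDistributionRecorder and BaseMultimodalProcessor are representative class names from those modules:

# 0.4.8.post1 paths (hypothetical caller code):
# from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
# from sglang.srt.managers.multimodal_processors.base_processor import BaseMultimodalProcessor

# 0.4.9 paths, following the renames listed above:
from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor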
sglang/srt/managers/schedule_batch.py

@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it
 
 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
@@ -96,8 +95,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
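GLOBAL_SERVER_ARGS_KEYS lists the ServerArgs fields that get mirrored into the process-global global_server_args_dict read by model code at runtime; the hunk above is how the new enable_flashinfer_allreduce_fusion toggle becomes visible there, while the swap of the two speculative_accept_threshold keys only changes list order. A sketch of the mirroring pattern (the key name is real; the surrounding code is illustrative, not the exact scheduler code):

from types import SimpleNamespace

# Stub standing in for the real ServerArgs instance:
server_args = SimpleNamespace(enable_flashinfer_allreduce_fusion=False)
GLOBAL_SERVER_ARGS_KEYS = ["enable_flashinfer_allreduce_fusion"]  # abridged

# Illustrative: the listed ServerArgs fields are copied into the global dict.
global_server_args_dict = {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
assert global_server_args_dict["enable_flashinfer_allreduce_fusion"] is False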
@@ -176,50 +175,63 @@ class Modality(Enum):
     VIDEO = auto()
     AUDIO = auto()
 
+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )
+
 
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.
+
+    We put the common fields first and the model-specific fields last.
     """
 
     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
 
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
-    #
-
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
 
-
-
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
 
-    # gemma3n
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None
 
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:
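The new Modality.from_str helper turns modality names arriving as strings (for example, from serialized IO structs) into enum members, with a descriptive error on bad input. A standalone sketch of its behavior; the member list here is reduced to what the hunk shows:

from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()  # assumed member; only VIDEO and AUDIO are visible above
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. "
                f"Valid modalities are: {[m.name for m in Modality]}"
            )

assert Modality.from_str("audio") is Modality.AUDIO
# Modality.from_str("text") -> ValueError listing IMAGE, VIDEO, AUDIO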
@@ -230,63 +242,18 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
+        from sglang.srt.managers.mm_utils import hash_feature
+
+        if self.hash is None:
+            if self.precomputed_features is not None:
+                self.hash = hash_feature(self.precomputed_features)
+            elif self.is_audio():
+                if self.audio_features is not None:
+                    self.hash = hash_feature(self.audio_features)
+                elif self.input_features is not None:
+                    self.hash = hash_feature(self.input_features)
             else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
-
-        if self.precomputed_features is not None:
-            self.hash = hash_feature(self.precomputed_features)
-        elif self.is_audio():
-            if self.audio_features is not None:
-                self.hash = hash_feature(self.audio_features)
-            elif self.input_features is not None:
-                self.hash = hash_feature(self.input_features)
-        else:
-            self.hash = hash_feature(self.pixel_values)
+                self.hash = hash_feature(self.pixel_values)
 
         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
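set_pad_value no longer carries inline hashing helpers; they now live behind hash_feature in sglang.srt.managers.mm_utils, and the method only picks which feature to hash. A self-contained sketch of the hash-then-fold idea, mirroring the removed data_hash/tensor_hash logic minus the CUDA fast path (a stand-in, not the library function):

import hashlib

import numpy as np
import torch

def hash_feature_sketch(f) -> int:
    # Reduce the feature to contiguous bytes, then to a 64-bit digest.
    if isinstance(f, torch.Tensor):
        t = f.detach().contiguous()
        if t.dtype == torch.bfloat16:
            t = t.float()  # numpy/memoryview cannot read bfloat16 buffers
        data = t.cpu().numpy().tobytes()
    elif isinstance(f, np.ndarray):
        data = np.ascontiguousarray(f).tobytes()
    else:
        raise TypeError(f"unsupported feature type: {type(f)}")
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")

h = hash_feature_sketch(torch.ones(4, dtype=torch.bfloat16))
pad_value = h % (1 << 30)  # the same folding used by set_pad_value above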
@@ -329,6 +296,13 @@ class MultimodalDataItem:
         ret.validate()
         return ret
 
+    def merge(self, other):
+        self.pixel_values += other.pixel_values
+        self.image_sizes += other.image_sizes
+        self.image_offsets += other.image_offsets
+        self.hash = hash((self.hash, other.hash))
+        self.set_pad_value()
+
 
 @dataclasses.dataclass
 class MultimodalInputs:
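merge concatenates with +=, so it assumes the per-image fields (pixel_values, image_sizes, image_offsets) hold Python lists, and it combines the two hashes before re-deriving the pad value. A toy model of the same logic (stub dataclass, not the real MultimodalDataItem):

import dataclasses

@dataclasses.dataclass
class ItemSketch:
    pixel_values: list
    image_sizes: list
    image_offsets: list
    hash: int

    def merge(self, other: "ItemSketch"):
        self.pixel_values += other.pixel_values
        self.image_sizes += other.image_sizes
        self.image_offsets += other.image_offsets
        self.hash = hash((self.hash, other.hash))  # real class then calls set_pad_value()

a = ItemSketch([b"img0"], [(336, 336)], [0], hash=11)
a.merge(ItemSketch([b"img1"], [(336, 336)], [576], hash=22))
assert len(a.pixel_values) == 2 and a.hash == hash((11, 22))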
@@ -339,10 +313,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
@@ -358,6 +328,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None
 
+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
@@ -485,6 +459,9 @@ class Req:
         # for corss-endoder model
         self.token_type_ids = token_type_ids
 
+        # The length of KV that have been removed in local attention chunked prefill
+        self.evicted_seqlen_local = 0
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -863,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
@@ -1191,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
+            if isinstance(self.tree_cache, SWAChunkCache):
+                self.tree_cache.evict(
+                    req, pre_len, self.model_config.attention_chunk_size
+                )
 
         # If input_embeds are available, store them
         if req.input_embeds is not None:
@@ -1383,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-
        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True
 
@@ -1564,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # free memory
+        if isinstance(self.tree_cache, SWAChunkCache):
+            for req in self.reqs:
+                self.tree_cache.evict(
+                    req, req.seqlen - 1, self.model_config.attention_chunk_size
+                )
+
         # Allocate memory
         if self.token_to_kv_pool_allocator.page_size == 1:
             self.out_cache_loc = self.alloc_token_slots(bs)
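Both SWAChunkCache.evict call sites above implement the same policy: with sliding-window (local) attention, KV entries more than attention_chunk_size tokens behind the current position are never read again, so they are freed during chunked prefill and after each decode step, with the new Req.evicted_seqlen_local counter tracking how much has been released. The window arithmetic, as a simplified illustrative helper that ignores chunk alignment (not the library API):

def evictable_prefix_len(seqlen: int, attention_chunk_size: int) -> int:
    # Tokens strictly before the current local-attention window can be freed.
    return max(seqlen - attention_chunk_size, 0)

# An 8192-token request with a 2048-token window keeps only the most recent
# 2048 KV entries resident:
assert evictable_prefix_len(8192, 2048) == 6144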
@@ -1694,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["attention_backend"] == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
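The or-chain above grows with each backend that needs seq_lens_cpu; the clauses shown are equivalent to a membership test. A sketch (needs_seq_lens_cpu is an invented name, the stub dict stands in for sglang's global_server_args_dict, and the elided first clause of the condition is ignored):

global_server_args_dict = {
    "attention_backend": "ascend",
    "enable_two_batch_overlap": False,
}

backend = global_server_args_dict["attention_backend"]
needs_seq_lens_cpu = (
    backend in ("flashmla", "cutlass_mla", "ascend")
    or global_server_args_dict["enable_two_batch_overlap"]
)
assert needs_seq_lens_cpu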
@@ -1726,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            is_extend_in_batch=self.is_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,
@@ -1798,7 +1788,6 @@ class ModelWorkerBatch:
     seq_lens: torch.Tensor
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
-
     # The sequence length tensor on CPU
     seq_lens_cpu: Optional[torch.Tensor]
     seq_lens_sum: int
@@ -1811,6 +1800,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
@@ -1897,7 +1887,10 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if global_server_args_dict["attention_backend"] != "torch_native":
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
         impl = get_last_loc_triton
     else:
         impl = get_last_loc_torch
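The widened condition routes Ascend and torch-native runs to the pure-PyTorch fallback and everything else to the Triton kernel. For reference, the fallback's semantics: for each request, return the KV-cache location of its last prefix token, or -1 when the prefix is empty. A sketch consistent with the signature above (not necessarily the exact library code):

import torch

def get_last_loc_torch_sketch(
    req_to_token: torch.Tensor,             # [num_req_slots, max_context_len]
    req_pool_indices_tensor: torch.Tensor,  # one pool slot per request
    prefix_lens_tensor: torch.Tensor,       # matched prefix length per request
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        # index -1 wraps around but is masked out by the where()
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )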