sglang-0.4.8-py3-none-any.whl → sglang-0.4.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

```diff
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 
 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
```
```diff
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
```
```diff
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
```
```diff
@@ -96,8 +95,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
```
```diff
@@ -176,45 +175,62 @@ class Modality(Enum):
     VIDEO = auto()
     AUDIO = auto()
 
+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )
+
 
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.
+
+    We put the common fields first and the model-specific fields last.
     """
 
     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
 
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
-    #
-    image_grid_hws: Optional[List[torch.Tensor]] = None
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
 
-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
 
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+    # For gemma3n
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
 
     @staticmethod
     def is_empty_list(l):
```
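The added `Modality.from_str` maps a case-insensitive modality name onto an enum member and raises `ValueError` listing the valid names otherwise. A minimal usage sketch (the import path is assumed from this file and is not part of the diff):

```python
from sglang.srt.managers.schedule_batch import Modality

assert Modality.from_str("video") is Modality.VIDEO
assert Modality.from_str("AUDIO") is Modality.AUDIO

try:
    Modality.from_str("hologram")  # not a member
except ValueError as e:
    print(e)  # message lists the valid modality names
```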
```diff
@@ -226,60 +242,18 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
+        from sglang.srt.managers.mm_utils import hash_feature
+
+        if self.hash is None:
+            if self.precomputed_features is not None:
+                self.hash = hash_feature(self.precomputed_features)
+            elif self.is_audio():
+                if self.audio_features is not None:
+                    self.hash = hash_feature(self.audio_features)
+                elif self.input_features is not None:
+                    self.hash = hash_feature(self.input_features)
             else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
-
-        if self.precomputed_features is not None:
-            self.hash = hash_feature(self.precomputed_features)
-        elif self.is_audio():
-            self.hash = hash_feature(self.audio_features)
-        else:
-            self.hash = hash_feature(self.pixel_values)
+                self.hash = hash_feature(self.pixel_values)
 
         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
```
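The inline `data_hash`/`tensor_hash`/`hash_feature` helpers deleted here move to `sglang.srt.managers.mm_utils`, which matches the `hashlib` and `gpu_tensor_hash` imports removed at the top of the file. The observable contract of `set_pad_value` is unchanged; a sketch of it, assuming a CPU torch tensor payload:

```python
import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.IMAGE)
item.pixel_values = [torch.zeros(3, 8, 8)]  # a list of tensors is accepted
item.set_pad_value()

# The hash is content-derived; pad_value folds it into 30 bits.
assert item.hash is not None
assert item.pad_value == item.hash % (1 << 30)
```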
```diff
@@ -288,6 +262,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )
 
     def is_image(self):
```
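With the gemma3n fields, an audio item may carry `input_features` instead of `audio_features`, and `is_audio()` now accepts either. A small sketch under the same assumptions as above (exact `is_empty_list` semantics are as defined in this file):

```python
import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.AUDIO)
item.input_features = torch.zeros(1, 80, 100)  # gemma3n-style features
assert item.is_audio()
```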
```diff
@@ -321,6 +296,13 @@ class MultimodalDataItem:
         ret.validate()
         return ret
 
+    def merge(self, other):
+        self.pixel_values += other.pixel_values
+        self.image_sizes += other.image_sizes
+        self.image_offsets += other.image_offsets
+        self.hash = hash((self.hash, other.hash))
+        self.set_pad_value()
+
 
 @dataclasses.dataclass
 class MultimodalInputs:
```
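The new `merge()` concatenates the per-image fields of two items, mixes the two content hashes, and re-derives `pad_value` via `set_pad_value`. A hedged sketch (assumes list-typed fields, as the multimodal processors produce):

```python
import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

a = MultimodalDataItem(modality=Modality.IMAGE)
b = MultimodalDataItem(modality=Modality.IMAGE)
a.pixel_values, b.pixel_values = [torch.zeros(3, 8, 8)], [torch.ones(3, 8, 8)]
a.image_sizes, b.image_sizes = [(8, 8)], [(8, 8)]
a.image_offsets, b.image_offsets = [0], [64]
a.set_pad_value()
b.set_pad_value()

a.merge(b)  # a now holds both images; its hash mixes both item hashes
assert len(a.pixel_values) == 2 and a.pad_value == a.hash % (1 << 30)
```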
```diff
@@ -331,10 +313,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
```
```diff
@@ -350,6 +328,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None
 
+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
```
```diff
@@ -477,6 +459,9 @@ class Req:
         # for corss-endoder model
         self.token_type_ids = token_type_ids
 
+        # The length of KV that have been removed in local attention chunked prefill
+        self.evicted_seqlen_local = 0
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
```
```diff
@@ -855,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
```
```diff
@@ -1183,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
+            if isinstance(self.tree_cache, SWAChunkCache):
+                self.tree_cache.evict(
+                    req, pre_len, self.model_config.attention_chunk_size
+                )
 
             # If input_embeds are available, store them
             if req.input_embeds is not None:
```
```diff
@@ -1375,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-
        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True
 
```
```diff
@@ -1556,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens.add_(1)
         self.seq_lens_sum += bs
 
+        # free memory
+        if isinstance(self.tree_cache, SWAChunkCache):
+            for req in self.reqs:
+                self.tree_cache.evict(
+                    req, req.seqlen - 1, self.model_config.attention_chunk_size
+                )
+
         # Allocate memory
         if self.token_to_kv_pool_allocator.page_size == 1:
             self.out_cache_loc = self.alloc_token_slots(bs)
```
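This is the second call site of `SWAChunkCache.evict`: the chunked-prefill hunk above evicts up to `pre_len`, while this decode-path hunk evicts per step up to `req.seqlen - 1`. Both free KV entries that have slid out of the local-attention window, tracked by the new `Req.evicted_seqlen_local` counter. The cache internals are not part of this diff; a hypothetical sketch of the idea:

```python
def evict(req, seqlen: int, attention_chunk_size: int, free_kv_range) -> None:
    """Hypothetical sketch: free_kv_range(req, start, end) releases KV slots
    for positions [start, end); only the call signature appears in the diff."""
    # Largest chunk-aligned prefix length the local window no longer needs.
    evictable = max(0, seqlen - attention_chunk_size)
    evictable -= evictable % attention_chunk_size
    if evictable > req.evicted_seqlen_local:
        free_kv_range(req, req.evicted_seqlen_local, evictable)
        req.evicted_seqlen_local = evictable
```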
```diff
@@ -1686,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["attention_backend"] == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
```
```diff
@@ -1718,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            is_extend_in_batch=self.is_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,
```
```diff
@@ -1790,7 +1788,6 @@ class ModelWorkerBatch:
     seq_lens: torch.Tensor
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
-
     # The sequence length tensor on CPU
     seq_lens_cpu: Optional[torch.Tensor]
     seq_lens_sum: int
```
```diff
@@ -1803,6 +1800,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
```
```diff
@@ -1889,7 +1887,10 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if global_server_args_dict["attention_backend"] != "torch_native":
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
         impl = get_last_loc_triton
     else:
         impl = get_last_loc_torch
```
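`get_last_loc` now routes the ascend backend, alongside torch_native, to the plain-torch implementation; every other backend keeps the Triton kernel. For reference, a sketch of what the torch fallback computes (the helper name below is illustrative, not the library's):

```python
import torch

def get_last_loc_torch_sketch(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    prefix_lens: torch.Tensor,
) -> torch.Tensor:
    # Last KV-pool location of each request's cached prefix, or -1 if none.
    return torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, prefix_lens - 1],
        torch.full_like(prefix_lens, -1),
    )
```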