sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

```diff
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i

 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
```
```diff
@@ -53,10 +52,9 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
```
```diff
@@ -87,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "deepep_mode",
     "enable_ep_moe",
     "enable_flashinfer_moe",
+    "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
```
```diff
@@ -96,12 +95,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
+    "enable_triton_kernel_moe",
 ]

 # Put some global args for easy access
```
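The three hunks above only change which `ServerArgs` fields are mirrored for global access: two new keys (`enable_flashinfer_allreduce_fusion`, `enable_triton_kernel_moe`) and one reordered. As a rough sketch of how such a key list is consumed (the dict comprehension and the `SimpleNamespace` stand-in are illustrative assumptions, not the verbatim sglang code):

```python
from types import SimpleNamespace

GLOBAL_SERVER_ARGS_KEYS = [
    "enable_flashinfer_allreduce_fusion",
    "enable_triton_kernel_moe",
    "sampling_backend",
]

# stand-in for the parsed ServerArgs object
server_args = SimpleNamespace(
    enable_flashinfer_allreduce_fusion=False,
    enable_triton_kernel_moe=True,
    sampling_backend="flashinfer",
)

# mirror the selected fields into a plain dict so model code can read
# them without threading server_args through every call site
global_server_args_dict = {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
assert global_server_args_dict["enable_triton_kernel_moe"] is True
```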
```diff
@@ -176,50 +176,63 @@ class Modality(Enum):
     VIDEO = auto()
     AUDIO = auto()

+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )
+

 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.
+
+    We put the common fields first and the model-specific fields last.
     """

     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None

+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None

-    #
-
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

-
-
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None

-    # gemma3n
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:
```
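The new `Modality.from_str` gives multimodal processors a case-insensitive way to map strings onto the enum, failing with a message that lists the valid names. A usage sketch, assuming sglang 0.4.9.post1 is importable (only members visible in this hunk are used):

```python
from sglang.srt.managers.schedule_batch import Modality

assert Modality.from_str("video") is Modality.VIDEO
assert Modality.from_str("AUDIO") is Modality.AUDIO

try:
    Modality.from_str("point_cloud")
except ValueError as err:
    print(err)  # Invalid modality string: point_cloud. Valid modalities are: [...]
```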
```diff
@@ -230,63 +243,18 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
+        from sglang.srt.managers.mm_utils import hash_feature
+
+        if self.hash is None:
+            if self.precomputed_features is not None:
+                self.hash = hash_feature(self.precomputed_features)
+            elif self.is_audio():
+                if self.audio_features is not None:
+                    self.hash = hash_feature(self.audio_features)
+                elif self.input_features is not None:
+                    self.hash = hash_feature(self.input_features)
             else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
-
-        if self.precomputed_features is not None:
-            self.hash = hash_feature(self.precomputed_features)
-        elif self.is_audio():
-            if self.audio_features is not None:
-                self.hash = hash_feature(self.audio_features)
-            elif self.input_features is not None:
-                self.hash = hash_feature(self.input_features)
-        else:
-            self.hash = hash_feature(self.pixel_values)
+                self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
```
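The inline `data_hash` / `tensor_hash` / `hash_feature` helpers move out of this method into `sglang.srt.managers.mm_utils`, and hashing is now skipped when `self.hash` is already set (as after `merge`, below). The pad value remains the content hash folded into a 30-bit range, so identical multimodal inputs produce identical pad tokens and can share prefix-cache entries. A standalone sketch of that folding (illustrative only, not the sglang helper itself):

```python
import hashlib

import numpy as np


def content_hash(arr: np.ndarray) -> int:
    # hash the raw bytes of a contiguous array, keeping 64 bits
    digest = hashlib.sha256(np.ascontiguousarray(arr).tobytes()).digest()[:8]
    return int.from_bytes(digest, byteorder="big", signed=False)


img_a = np.zeros((3, 32, 32), dtype=np.float32)
img_b = np.ones((3, 32, 32), dtype=np.float32)

# fold into a token-id-sized range exactly as set_pad_value does: hash % (1 << 30)
pad_a = content_hash(img_a) % (1 << 30)
pad_b = content_hash(img_b) % (1 << 30)

assert pad_a == content_hash(img_a.copy()) % (1 << 30)  # same content, same pad value
assert pad_a != pad_b  # different content, (almost surely) different pad value
```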
```diff
@@ -329,6 +297,13 @@ class MultimodalDataItem:
         ret.validate()
         return ret

+    def merge(self, other):
+        self.pixel_values += other.pixel_values
+        self.image_sizes += other.image_sizes
+        self.image_offsets += other.image_offsets
+        self.hash = hash((self.hash, other.hash))
+        self.set_pad_value()
+

 @dataclasses.dataclass
 class MultimodalInputs:
```
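`merge` concatenates another item's per-image fields onto this one and combines the two hashes, after which `set_pad_value` only refreshes `pad_value` (the hash is no longer `None`). A hedged usage sketch; `Modality.IMAGE` and the list-typed field values are assumptions not shown in this hunk:

```python
import numpy as np

a = MultimodalDataItem(modality=Modality.IMAGE)
a.pixel_values = [np.zeros((3, 32, 32), dtype=np.float32)]
a.image_sizes = [(32, 32)]
a.image_offsets = [0]
a.set_pad_value()

b = MultimodalDataItem(modality=Modality.IMAGE)
b.pixel_values = [np.ones((3, 32, 32), dtype=np.float32)]
b.image_sizes = [(32, 32)]
b.image_offsets = [5]
b.set_pad_value()

a.merge(b)  # a now carries both images
# a.hash == hash((old a.hash, b.hash)) and a.pad_value was recomputed from it
```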
```diff
@@ -339,10 +314,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
```
```diff
@@ -358,6 +329,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None

+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
```
```diff
@@ -485,6 +460,9 @@ class Req:
         # for corss-endoder model
         self.token_type_ids = token_type_ids

+        # The length of KV that have been removed in local attention chunked prefill
+        self.evicted_seqlen_local = 0
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
```
```diff
@@ -863,8 +841,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
-    can_run_dp_cuda_graph: bool = False
     is_extend_in_batch: bool = False
+    can_run_dp_cuda_graph: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None

```
```diff
@@ -1191,6 +1169,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
+            if isinstance(self.tree_cache, SWAChunkCache):
+                self.tree_cache.evict(
+                    req, pre_len, self.model_config.attention_chunk_size
+                )

             # If input_embeds are available, store them
             if req.input_embeds is not None:
```
```diff
@@ -1383,7 +1365,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-
         if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True

```
```diff
@@ -1564,6 +1545,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens.add_(1)
         self.seq_lens_sum += bs

+        # free memory
+        if isinstance(self.tree_cache, SWAChunkCache):
+            for req in self.reqs:
+                self.tree_cache.evict(
+                    req, req.seqlen - 1, self.model_config.attention_chunk_size
+                )
+
         # Allocate memory
         if self.token_to_kv_pool_allocator.page_size == 1:
             self.out_cache_loc = self.alloc_token_slots(bs)
```
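Both `SWAChunkCache.evict` call sites (chunked prefill above, decode here) serve models with chunked local attention, where `attention_chunk_size` bounds how far back a token can attend. Once a request's position passes a chunk boundary, KV entries before that boundary can never be read again and are freed, with per-request progress tracked in the new `Req.evicted_seqlen_local`. A sketch of the boundary arithmetic this implies (my reading of the hunks, not the `SWAChunkCache` source):

```python
def freeable_prefix_len(current_pos: int, attention_chunk_size: int) -> int:
    # start of the chunk containing current_pos; KV for everything before
    # it lies outside the local-attention window and can be released
    return (current_pos // attention_chunk_size) * attention_chunk_size

# with attention_chunk_size=8192, a request decoding at position 20000
# can free the KV slots for positions [0, 16384)
assert freeable_prefix_len(20000, 8192) == 16384
```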
```diff
@@ -1694,6 +1682,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["attention_backend"] == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
```
```diff
@@ -1726,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            is_extend_in_batch=self.is_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,
```
```diff
@@ -1798,7 +1788,6 @@ class ModelWorkerBatch:
     seq_lens: torch.Tensor
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
-
     # The sequence length tensor on CPU
     seq_lens_cpu: Optional[torch.Tensor]
     seq_lens_sum: int
```
```diff
@@ -1811,6 +1800,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
+    is_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
```
```diff
@@ -1897,7 +1887,10 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if global_server_args_dict["attention_backend"] != "torch_native":
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
         impl = get_last_loc_triton
     else:
         impl = get_last_loc_torch
```
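`get_last_loc` returns, per request, the KV-cache slot of the last prefix token, or -1 for an empty prefix; this change routes the new `ascend` backend, like `torch_native`, to the pure-PyTorch implementation instead of the Triton kernel. A minimal torch version consistent with that contract (a sketch; the shipped `get_last_loc_torch` may differ in detail):

```python
import torch


def get_last_loc_torch(
    req_to_token: torch.Tensor,  # [num_req_slots, max_context_len]
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    # slot of token prefix_len - 1 for each request; -1 when prefix_len == 0
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )
```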