sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py:

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

+logger = logging.getLogger(__name__)
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
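Aside (not part of the diff): the new `MultiItemScoringParams` dataclass is plain data plus an `is_enabled()` check keyed on `prefix_len_ptr`. A rough, hypothetical illustration of packing the docstring's example of three items of length 3, 2, and 4, assuming a 7-token query prefix before the first delimiter and that the class is importable from `sglang.srt.layers.attention.flashinfer_backend`:

```python
import torch

from sglang.srt.layers.attention.flashinfer_backend import MultiItemScoringParams

# Hypothetical single-prompt example: items of length 3, 2, 4 and an assumed
# 7-token prefix before the first delimiter (values mirror the docstring).
params = MultiItemScoringParams(
    prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),
    token_pos_in_items_ptr=torch.tensor(
        [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], dtype=torch.uint16
    ),
    token_pos_in_items_len=13,  # unpadded length, batch size 1
    max_item_len_ptr=torch.tensor([4], dtype=torch.uint16),
)
assert params.is_enabled()  # True as soon as prefix_len_ptr is set
```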
@@ -68,6 +99,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None


 # Reuse this workspace buffer across all flashinfer wrappers
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):

         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend

+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
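Aside (not part of the diff): the core of `_process_multi_item_scoring` is the `cummax` trick that turns delimiter positions into per-item offsets. A minimal standalone sketch, with made-up token ids, positions starting at 0, and no cached prefix, reproduces Case 1 from the docstring:

```python
import torch

# Case 1 from the docstring: a 7-token query followed by four single-token
# items, each introduced by a delimiter. 99 is a made-up delimiter id.
delimiter = 99
input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 99, 8, 99, 9, 99, 10, 99])
positions = torch.arange(input_ids.numel())

mask = (input_ids == delimiter).long()  # 0/1 delimiter mask
first_delim = int(torch.nonzero(mask, as_tuple=True)[0][0])

# cummax over the 0/1 mask returns, per position, the index where the running
# maximum was reached, i.e. the most recent delimiter seen so far.
last_delim_idx = torch.cummax(mask[first_delim:], 0)[1]
diff = positions[first_delim:] - last_delim_idx
token_pos = diff - positions[first_delim]

print(first_delim)         # 7 -> one entry of prefix_len_ptr
print(token_pos.tolist())  # [0, 1, 0, 1, 0, 1, 0] -> token_pos_in_items_ptr
```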
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal:
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
                 use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
                 fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )

     def init_cuda_graph_state(
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )

    def update_sliding_window(
@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )

     def update_cross_attention(
@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )

     def call_begin_forward(
@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
             )

         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )

sglang/srt/layers/attention/flashinfer_mla_backend.py:

@@ -1060,7 +1060,7 @@ def fast_mla_decode_plan(

     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan
+        self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
|