sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
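
The largest change in this post release is multi-item scoring support in the FlashInfer attention backend (sglang/srt/layers/attention/flashinfer_backend.py, +223 -6), shown in the hunks below. As a reading aid, here is a minimal standalone sketch (not part of the diff; the delimiter id 7 and the token ids are invented for illustration) of how the token_pos_in_items_ptr layout described in the new docstring falls out of a cumulative max over the delimiter mask:

    # Standalone sketch: item-relative positions restart at 0 on every delimiter.
    import torch

    input_ids = torch.tensor([11, 12, 13, 14, 15, 16, 17, 7, 21, 7, 22, 7, 23, 7])  # 7 = <delim>
    mask = input_ids == 7
    first_delim = int(torch.nonzero(mask)[0])   # 7, so prefix_len_ptr would be [7]
    pos = torch.arange(input_ids.numel())
    # cummax over the mask gives, for each token, the slice-relative index of the
    # most recent delimiter; subtracting it restarts the count at each item.
    last_delim = torch.cummax(mask[first_delim:], 0)[1]
    token_pos = pos[first_delim:] - first_delim - last_delim
    print(token_pos.tolist())                   # [0, 1, 0, 1, 0, 1, 0]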
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True
 
+logger = logging.getLogger(__name__)
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
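
Note: the defaults of MultiItemScoringParams double as a disabled sentinel, so call sites can construct one unconditionally and branch on is_enabled(). A small sketch (illustrative values; the uint16/uint32 dtypes mirror the attribute docs above):

    import torch

    params = MultiItemScoringParams()           # all defaults -> disabled
    assert not params.is_enabled()

    params = MultiItemScoringParams(
        prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),
        token_pos_in_items_ptr=torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint16),
        token_pos_in_items_len=7,
        max_item_len_ptr=torch.tensor([1], dtype=torch.uint16),
    )
    assert params.is_enabled()                  # prefix_len_ptr set -> enabled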
@@ -68,6 +99,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
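
Note: besides building the FlashInfer tensors, _process_multi_item_scoring rewrites forward_batch.positions in place (pos[first_delim:] = diff - 1), so every item is position-encoded as if it directly followed the shared prefix. A standalone check of that rewrite for the docstring's Case 1 (token layout as above, no cached prefix):

    import torch

    pos = torch.arange(14)                  # positions of the 14-token sequence
    mask = torch.zeros(14, dtype=torch.bool)
    mask[[7, 9, 11, 13]] = True             # <delim> token indices
    first = 7
    diff = pos[first:] - torch.cummax(mask[first:], 0)[1]
    pos[first:] = diff - 1
    print(pos.tolist())
    # [0, 1, 2, 3, 4, 5, 6, 6, 7, 6, 7, 6, 7, 6]
    # every delimiter lands on position 6 and every item token on 7, i.e. each
    # item continues directly from the 7-token shared prefix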
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-        if self.is_multimodal:
+        # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+        if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+            # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+            # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+            #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+            # 2. Paged wrapper provides better control over attention masking needed
+            #    for respecting item boundaries in multi-item sequences
+            # 3. Custom masking logic conflicts with ragged wrapper's assumptions
             use_ragged = False
             extend_no_prefix = False
         else:
             use_ragged = not self.enable_deterministic
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+        # Process multi-item scoring in attention backend instead of ForwardBatch
+        multi_item_params = MultiItemScoringParams()
+        if self.multi_item_scoring_delimiter is not None:
+            # Use new backend-specific implementation
+            multi_item_params = self._process_multi_item_scoring(forward_batch)
+
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
             forward_batch.seq_lens,
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
             encoder_lens=forward_batch.encoder_lens,
             spec_info=None,
             fixed_split_size=self.prefill_split_tile_size,
+            multi_item_params=multi_item_params,
         )
         self.forward_metadata = PrefillMetadata(
-            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            self.prefill_wrappers_paged,
+            use_ragged,
+            extend_no_prefix,
+            multi_item_params,
         )
 
     def init_cuda_graph_state(
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
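
Note: the else branch above leans on FlashInfer's convention that window_left == -1 means an unbounded window (full attention); that is what actually disables sliding-window attention when multi-item parameters are enabled. The conditional reduces to this sketch:

    def effective_window_left(sliding_window_size: int, multi_item_enabled: bool) -> int:
        # -1 = full attention in FlashInfer, keeping whole items mutually visible
        return -1 if multi_item_enabled else sliding_window_size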
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1167,7 @@
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )
 
     def update_sliding_window(
@@ -990,6 +1182,7 @@
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1216,7 @@
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )
 
     def update_cross_attention(
@@ -1037,6 +1231,7 @@
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1258,7 @@
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )
 
     def call_begin_forward(
@@ -1081,6 +1277,7 @@
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1333,22 @@
             )
 
         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1147,9 +1360,13 @@
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )

sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -1060,7 +1060,7 @@ def fast_mla_decode_plan(
 
     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan.default(
+        self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
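
Note: the diff gives no rationale for this last one-line fix, but a plausible reading is that .default exists only when the cached module exposes plan as a torch custom-op overload packet; when it is a plain Python function (as in newer FlashInfer builds), plan.default raises AttributeError while a direct call works for both packagings. A generic illustration with a stand-in function (not FlashInfer code):

    def plan(*args):  # stand-in for self._cached_module.plan exposed as a function
        return "planned"

    assert not hasattr(plan, "default")  # plan.default would raise AttributeError
    assert plan() == "planned"           # the direct call works either way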