sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/aiter_backend.py

```diff
@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
```
```diff
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]
 
 
 global_workspace_buffer = None
```
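The hunk above collapses four per-batch length fields into two. Read together with the constructor calls later in this diff (six positional arguments everywhere), the resulting dataclass has the following shape; a minimal sketch, assuming the `kv_indptr` field sits just above this hunk's context window (the upstream class may carry more state):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    kv_last_page_len: torch.Tensor
    max_q_len: int             # max query span per request (1 for pure decode)
    max_kv_len: Optional[int]  # None in branches that do not need it
```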
```diff
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
```
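The constructor change above imports the Triton extend kernel lazily and wraps it in `torch.compiler.disable`, keeping the kernel launch out of any `torch.compile` trace. A minimal sketch of that wrapping pattern; `heavy_kernel` is a hypothetical stand-in, not a function from this diff:

```python
import torch


def heavy_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for extend_attention_fwd (a Triton kernel launch).
    return x * 2


# Same pattern as the hunk: the wrapped callable always runs eagerly and is
# skipped by torch.compile/Dynamo tracing.
safe_kernel = torch.compiler.disable(heavy_kernel)
print(safe_kernel(torch.ones(2)))  # tensor([2., 2.])
```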
```diff
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-
+        max_q_len = None
 
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
```
```diff
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-
+                max_q_len = 1
 
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-
-                None,
-                None,
+                max_q_len,
                 None,
             )
 
         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-
-
-
-
-
-
-
-                    spec_info=None,
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
                 )
                 self.forward_metadata = ForwardMetadata(
-
-
-
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.
-
-
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
```
```diff
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-
-
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
                     forward_batch.req_pool_indices,
-
-
-
-
-
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
                 )
                 self.forward_metadata = ForwardMetadata(
-
-
-
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.
-
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
                     None,
                 )
             else:
```
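In the new target-verify path above, every request contributes exactly `draft_num` query tokens, so `qo_indptr` is a fixed-stride `arange`, while the KV spans cover committed plus draft tokens via a `cumsum`. A small worked example with assumed sizes (bs=3, draft_num=4), not values from the diff:

```python
import torch

bs, draft_num = 3, 4
seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32)  # assumed committed lengths

# Same construction as the hunk: one fixed-size query span per request ...
qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
# ... and KV spans covering committed tokens plus the draft tokens.
kv_lens = seq_lens + draft_num
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

print(qo_indptr.tolist())  # [0, 4, 8, 12]
print(kv_indptr.tolist())  # [0, 9, 15, 26]
```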
```diff
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
```
```diff
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-
-
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.
-                    self.mla_indices_updater_prefill.
-                    self.mla_indices_updater_prefill.
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
```
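The `kv_indptr += qo_indptr` adjustment added above offsets the prefix-only KV pointers by the query pointers, so entry i ends after that request's cached prefix plus its new extend tokens. A toy illustration with assumed lengths:

```python
import torch

# Assumed per-request lengths: tokens already cached (prefix) vs. newly fed (extend).
prefix_lens = torch.tensor([4, 0, 2], dtype=torch.int32)
extend_lens = torch.tensor([3, 5, 1], dtype=torch.int32)

kv_indptr = torch.zeros(4, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)  # spans over cached prefixes only
qo_indptr = torch.zeros(4, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)  # spans over new tokens only

# Effect of the added line: entry i now ends after prefix + extend tokens of request i.
kv_indptr += qo_indptr
print(kv_indptr.tolist())  # [0, 7, 12, 15]
```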
```diff
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
```
```diff
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-
+            max_q_len = None
 
             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1
 
             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-
-                None,
-                None,
+                max_q_len,
                 None,
             )
 
@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
-
-
-                kv_last_page_len = None
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = self.num_draft_tokens
 
                 self.forward_metadata = ForwardMetadata(
                     kv_indptr,
                     kv_indices,
                     qo_indptr,
                     kv_last_page_len,
-
-                    None,
-                    None,
+                    max_q_len,
                     None,
                 )
             else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
-
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
```
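The new draft-extend capture branch above sizes each query span at `speculative_num_steps + 1` tokens per request, so `qo_indptr` is again a fixed-stride `arange`. With assumed bs=2 and three speculative steps:

```python
import torch

# Assumed sizes for illustration: 2 requests, 3 speculative steps.
bs, speculative_num_steps = 2, 3
num_tokens_per_bs = speculative_num_steps + 1  # as in the hunk

qo_indptr = torch.arange(
    0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
)
print(qo_indptr.tolist())  # [0, 4, 8]: a fixed 4-token query span per request
```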
```diff
@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
             kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
 
         elif forward_mode.is_target_verify():
-
-
-
-
-
-
-
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
         else:
             raise ValueError("Invalid forward mode")
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
         if self.use_mla:
-
-
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-
-
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-
-
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
```
```diff
@@ -720,11 +848,6 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper
 
-        # get the last index of the pool
-        self.pool_size = (
-            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
-        ) - 1
-
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -769,9 +892,8 @@ class AiterIndicesUpdaterPrefill:
         # but the 0 location will be made nan (noqa) in cuda graph capture mode
         # this will cause the output tensor value becomes nan
         # WA is to assure that last index of pool not changed
-        kv_indices = torch.
-
-            self.pool_size,
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + 256,
             dtype=torch.int32,
             device=req_pool_indices.device,
         )
@@ -785,6 +907,9 @@ class AiterIndicesUpdaterPrefill:
             self.req_to_token.shape[1],
         )
 
+        token_num = kv_indptr[-1]
+        kv_indices[token_num:] = kv_indices[0]
+
         self.max_kv_len = torch.max(paged_kernel_lens).item()
 
         extend_lens = seq_lens - prefix_lens
```
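The two lines added above replace the old `pool_size` sentinel: the tail of the over-allocated `kv_indices` buffer is now filled with its own first entry, so the padding always references a valid KV slot and aiter's ragged paged attention cannot read an uninitialized page during CUDA graph capture. A toy illustration (slot values assumed, not from the diff):

```python
import torch

# token_num real entries plus padding; values are assumed KV slot ids.
token_num = 5
kv_indices = torch.empty(token_num + 3, dtype=torch.int32)
kv_indices[:token_num] = torch.tensor([7, 8, 9, 12, 13], dtype=torch.int32)
kv_indices[token_num:] = kv_indices[0]  # pad with a known-valid slot, as in the hunk
print(kv_indices.tolist())  # [7, 8, 9, 12, 13, 7, 7, 7]
```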
```diff
@@ -819,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.
-        self.
+        self.max_q_len = 0
+        self.max_kv_len = 0
 
     def update(
         self,
         req_pool_indices: torch.Tensor,
-
-
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -837,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-
-
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
-
         bs = len(req_pool_indices)
 
         kv_indptr = self.attn_backend.kv_indptr
 
         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -873,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-
-
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -890,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.
-        self.
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
```