sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/bench_serving.py +11 -3
  3. sglang/lang/backend/openai.py +10 -0
  4. sglang/srt/configs/model_config.py +11 -2
  5. sglang/srt/constrained/xgrammar_backend.py +6 -0
  6. sglang/srt/layers/attention/__init__.py +0 -1
  7. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  9. sglang/srt/layers/logits_processor.py +30 -2
  10. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
  11. sglang/srt/layers/moe/topk.py +14 -0
  12. sglang/srt/layers/quantization/fp8.py +42 -2
  13. sglang/srt/layers/quantization/fp8_kernel.py +91 -18
  14. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  15. sglang/srt/managers/io_struct.py +29 -8
  16. sglang/srt/managers/schedule_batch.py +22 -15
  17. sglang/srt/managers/schedule_policy.py +1 -1
  18. sglang/srt/managers/scheduler.py +71 -34
  19. sglang/srt/managers/session_controller.py +102 -27
  20. sglang/srt/managers/tokenizer_manager.py +95 -55
  21. sglang/srt/managers/tp_worker.py +7 -0
  22. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  23. sglang/srt/model_executor/forward_batch_info.py +42 -3
  24. sglang/srt/model_executor/model_runner.py +4 -6
  25. sglang/srt/model_loader/loader.py +22 -11
  26. sglang/srt/models/gemma2.py +19 -0
  27. sglang/srt/models/llama.py +13 -2
  28. sglang/srt/models/llama_eagle.py +132 -0
  29. sglang/srt/openai_api/adapter.py +79 -2
  30. sglang/srt/openai_api/protocol.py +50 -0
  31. sglang/srt/sampling/sampling_params.py +9 -2
  32. sglang/srt/server.py +45 -39
  33. sglang/srt/server_args.py +17 -30
  34. sglang/srt/speculative/spec_info.py +19 -0
  35. sglang/srt/utils.py +62 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
  38. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
  39. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -331,6 +331,7 @@ def throughput_test(
  extra_request_body=extra_request_body,
  profile=bench_args.profile,
  )
+ backend.shutdown()

  if bench_args.result_filename:
  with open(bench_args.result_filename, "a") as fout:
sglang/bench_serving.py CHANGED
@@ -897,6 +897,7 @@ async def benchmark(
  else:
  raise ValueError(f"Unknown backend: {backend}")

+ # Limit concurrency
  # From https://github.com/vllm-project/vllm/pull/9390
  semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

@@ -906,6 +907,7 @@ async def benchmark(
  async with semaphore:
  return await request_func(request_func_input=request_func_input, pbar=pbar)

+ # Warmup
  print("Starting initial single prompt test run...")
  test_prompt, test_prompt_len, test_output_len = input_requests[0]
  test_input = RequestFuncInput(
@@ -924,11 +926,15 @@ async def benchmark(
  f"are correctly specified. Error: {test_output.error}"
  )
  else:
- requests.post(base_url + "/flush_cache")
  print("Initial test run completed. Starting main benchmark run...")

- time.sleep(1.5)
+ # Flush cache
+ if "sglang" in backend:
+ requests.post(base_url + "/flush_cache")
+
+ time.sleep(1.0)

+ # Start profiler
  if profile:
  print("Starting profiler...")
  profile_output = await async_request_profile(
@@ -939,6 +945,7 @@ async def benchmark(

  pbar = None if disable_tqdm else tqdm(total=len(input_requests))

+ # Run all requests
  benchmark_start_time = time.perf_counter()
  tasks: List[asyncio.Task] = []
  async for request in get_request(input_requests, request_rate):
@@ -959,6 +966,7 @@ async def benchmark(
  )
  outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

+ # Stop profiler
  if profile:
  print("Stopping profiler...")
  profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
@@ -968,8 +976,8 @@ async def benchmark(
  if pbar is not None:
  pbar.close()

+ # Compute metrics and print results
  benchmark_duration = time.perf_counter() - benchmark_start_time
-
  metrics, output_lens = calculate_metrics(
  input_requests=input_requests,
  outputs=outputs,
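Taken together, the bench_serving.py changes add step comments, flush the sglang server's prefix cache after the warmup request, and shorten the post-warmup sleep. A minimal sketch of that warmup-then-flush step follows; only /flush_cache and /stop_profile appear verbatim in the diff, while the /start_profile path and the helper name are assumptions for illustration.

import time
import requests

def flush_and_prepare(backend: str, base_url: str, profile: bool = False) -> None:
    # After the single warmup request, flush the prefix cache so the warmup
    # prompt does not inflate cache-hit rates, but only for sglang-family
    # backends that expose the endpoint.
    if "sglang" in backend:
        requests.post(base_url + "/flush_cache")
    time.sleep(1.0)  # give the server a moment to settle

    if profile:
        # Assumed endpoint name; the diff only shows the stop call verbatim.
        requests.post(base_url + "/start_profile")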
sglang/lang/backend/openai.py CHANGED
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
  def openai_completion(
  client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
  ):
+ # if "ebnf" is in kwargs, warn and remove
+ if "ebnf" in kwargs:
+ warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+ del kwargs["ebnf"]
+
  for attempt in range(retries):
  try:
  if is_chat:
@@ -398,6 +403,11 @@ def openai_completion(
  def openai_completion_stream(
  client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
  ):
+ # if "ebnf" is in kwargs, warn and remove
+ if "ebnf" in kwargs:
+ warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
+ del kwargs["ebnf"]
+
  for attempt in range(retries):
  try:
  if is_chat:
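The same ebnf guard is duplicated in both the blocking and streaming completion paths. A standalone sketch of the pattern, using a hypothetical helper name that is not part of the package:

import warnings

def drop_unsupported_kwargs(kwargs: dict, unsupported: tuple = ("ebnf",)) -> dict:
    # Warn about and drop arguments that OpenAI endpoints do not accept,
    # instead of letting the request fail server-side.
    for key in unsupported:
        if key in kwargs:
            warnings.warn(f"{key!r} is not officially supported by OpenAI endpoints. Ignoring.")
            del kwargs[key]
    return kwargs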
sglang/srt/configs/model_config.py CHANGED
@@ -15,7 +15,7 @@
  import json
  import logging
  from enum import IntEnum, auto
- from typing import List, Optional, Union
+ from typing import List, Optional, Set, Union

  import torch
  from transformers import PretrainedConfig
@@ -47,6 +47,7 @@ class ModelConfig:
  self.model_path = model_path
  self.revision = revision
  self.quantization = quantization
+
  # Parse args
  self.model_override_args = json.loads(model_override_args)
  self.hf_config = get_config(
@@ -130,7 +131,8 @@ class ModelConfig:
  # Veirfy quantization
  self._verify_quantization()

- # Multimodel attrs
+ # Cache attributes
+ self.hf_eos_token_id = self.get_hf_eos_token_id()
  self.image_token_id = getattr(self.hf_config, "image_token_id", None)

  # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
@@ -271,6 +273,13 @@ class ModelConfig:
  self.quantization,
  )

+ def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+ eos_ids = getattr(self.hf_config, "eos_token_id", None)
+ if eos_ids:
+ # it can be either int or list of int
+ eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+ return eos_ids
+

  def get_hf_text_config(config: PretrainedConfig):
  """Get the "sub" config relevant to llm for multi modal models.
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
  f"Skip invalid json_schema: json_schema={key_string}, {e=}"
  )
  return None
+ elif key_type == "ebnf":
+ try:
+ ctx = self.grammar_compiler.compile_grammar(key_string)
+ except RuntimeError as e:
+ logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
+ return None
  elif key_type == "regex":
  logger.warning(
  "regex hasn't been supported by xgrammar yet. This is skipped."
sglang/srt/layers/attention/__init__.py CHANGED
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
  from typing import Optional

  import torch
- from torch import nn

  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
  """

  import os
+ from dataclasses import dataclass
  from enum import Enum, auto
- from typing import TYPE_CHECKING, List
+ from typing import TYPE_CHECKING, List, Union

  import torch
  import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
  CROSS_ATTENTION = auto()


+ @dataclass
+ class DecodeMetadata:
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+ @dataclass
+ class PrefillMetadata:
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+ use_ragged: bool
+ extend_no_prefix: bool
+
+
  class FlashInferAttnBackend(AttentionBackend):
  """Flashinfer attention kernels."""

  def __init__(self, model_runner: ModelRunner):
  super().__init__()

+ # Parse constants
  self.decode_use_tensor_cores = should_use_tensor_core(
  kv_cache_dtype=model_runner.kv_cache_dtype,
  num_attention_heads=model_runner.model_config.num_attention_heads
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
  model_runner.tp_size
  ),
  )
-
  self.max_context_len = model_runner.model_config.context_len

  assert not (
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
  )

  # Other metadata
- self.forward_metadata = None
- self.cuda_graph_metadata = {}
+ self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+ self.decode_cuda_graph_metadata = {}

  def init_forward_metadata(self, forward_batch: ForwardBatch):
  if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
  forward_batch.seq_lens_sum,
- decode_wrappers=None,
+ decode_wrappers=self.decode_wrappers,
  encoder_lens=forward_batch.encoder_lens,
  )
- self.forward_metadata = (self.decode_wrappers,)
+ self.forward_metadata = DecodeMetadata(self.decode_wrappers)
  else:
  prefix_lens = forward_batch.extend_prefix_lens

@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch.seq_lens,
  forward_batch.seq_lens_sum,
  prefix_lens,
+ prefill_wrappers=self.prefill_wrappers_paged,
  use_ragged=use_ragged,
  encoder_lens=forward_batch.encoder_lens,
  )
-
- self.forward_metadata = (use_ragged, extend_no_prefix)
+ self.forward_metadata = PrefillMetadata(
+ self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+ )

  def init_cuda_graph_state(self, max_bs: int):
  cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
  decode_wrappers=decode_wrappers,
  encoder_lens=encoder_lens,
  )
- self.cuda_graph_metadata[bs] = decode_wrappers
- self.forward_metadata = (decode_wrappers,)
+ self.decode_cuda_graph_metadata[bs] = decode_wrappers
+ self.forward_metadata = DecodeMetadata(decode_wrappers)

  def init_forward_metadata_replay_cuda_graph(
  self,
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
  req_pool_indices[:bs],
  seq_lens[:bs],
  seq_lens_sum,
- decode_wrappers=self.cuda_graph_metadata[bs],
+ decode_wrappers=self.decode_cuda_graph_metadata[bs],
  encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
  )

@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- prefill_wrapper_paged = self.prefill_wrappers_paged[
+ prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
  self._get_wrapper_idx(layer)
  ]
-
- use_ragged, extend_no_prefix = self.forward_metadata
  cache_loc = (
  forward_batch.out_cache_loc
  if not layer.is_cross_attention
  else forward_batch.encoder_out_cache_loc
  )

- if not use_ragged:
+ if not self.forward_metadata.use_ragged:
  if k is not None:
  assert v is not None
  if save_kv_cache:
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
  logits_soft_cap=layer.logit_cap,
  )

- if extend_no_prefix:
+ if self.forward_metadata.extend_no_prefix:
  o = o1
  else:
  o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
- decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+ decode_wrapper = self.forward_metadata.decode_wrappers[
+ self._get_wrapper_idx(layer)
+ ]
  cache_loc = (
  forward_batch.out_cache_loc
  if not layer.is_cross_attention
@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):

  class FlashInferIndicesUpdaterDecode:
  def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
- # Constants
+ # Parse Constants
  self.num_qo_heads = (
  model_runner.model_config.num_attention_heads // model_runner.tp_size
  )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
  self.kv_indptr = attn_backend.kv_indptr
  self.kv_last_page_len = attn_backend.kv_last_page_len
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
- self.decode_wrappers = attn_backend.decode_wrappers

- # Dispatch
+ # Dispatch the update function
  if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
  self.update = self.update_sliding_window
  elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List,
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: torch.Tensor,
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List,
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: torch.Tensor,
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List,
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: torch.Tensor,
  ):
- decode_wrappers = decode_wrappers or self.decode_wrappers
-
  for wrapper_id in range(2):
  if wrapper_id == 0:
  # Sliding window attention
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers: List,
+ decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: torch.Tensor,
  ):
- decode_wrappers = decode_wrappers or self.decode_wrappers
-
  for wrapper_id in range(2):
  if wrapper_id == 0:
  # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:

  def call_begin_forward(
  self,
- wrapper,
+ wrapper: BatchDecodeWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
@@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode:

  class FlashInferIndicesUpdaterPrefill:
  def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
- # Constants
+ # Parse Constants
  self.num_qo_heads = (
  model_runner.model_config.num_attention_heads // model_runner.tp_size
  )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
  self.kv_last_page_len = attn_backend.kv_last_page_len
  self.qo_indptr = attn_backend.qo_indptr
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
- self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
- self.wrappers_paged = attn_backend.prefill_wrappers_paged
+ self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

- # Dispatch
+ # Dispatch the update function
  if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
  self.update = self.update_sliding_window
  elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: torch.Tensor,
  ):
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: torch.Tensor,
  ):
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
  paged_kernel_lens_sum = seq_lens_sum

  self.call_begin_forward(
- self.wrapper_ragged,
- self.wrappers_paged[0],
+ self.prefill_wrapper_ragged,
+ prefill_wrappers[0],
  req_pool_indices,
  paged_kernel_lens,
  paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: torch.Tensor,
  ):
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
  kv_start_idx = seq_lens - paged_kernel_lens

  self.call_begin_forward(
- self.wrapper_ragged,
- self.wrappers_paged[wrapper_id],
+ self.prefill_wrapper_ragged,
+ prefill_wrappers[wrapper_id],
  req_pool_indices,
  paged_kernel_lens,
  paged_kernel_lens_sum,
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
+ prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
  use_ragged: bool,
  encoder_lens: torch.Tensor,
  ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
  paged_kernel_lens_sum = paged_kernel_lens.sum().item()

  self.call_begin_forward(
- self.wrapper_ragged,
- self.wrappers_paged[wrapper_id],
+ self.prefill_wrapper_ragged,
+ prefill_wrappers[wrapper_id],
  req_pool_indices,
  paged_kernel_lens,
  paged_kernel_lens_sum,
@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:

  def call_begin_forward(
  self,
- wrapper_ragged,
- wrapper_paged,
+ wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+ wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
  req_pool_indices: torch.Tensor,
  paged_kernel_lens: torch.Tensor,
  paged_kernel_lens_sum: int,
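Across the flashinfer_backend.py hunks above, forward_metadata changes from an anonymous tuple to typed DecodeMetadata/PrefillMetadata dataclasses, and the prefill wrappers are threaded through explicitly instead of being read off the backend. A minimal sketch of the shape of that change (wrapper classes stubbed as Any for brevity):

from dataclasses import dataclass
from typing import Any, List, Union

@dataclass
class DecodeMetadata:
    decode_wrappers: List[Any]  # BatchDecodeWithPagedKVCacheWrapper in the real code

@dataclass
class PrefillMetadata:
    prefill_wrappers: List[Any]  # BatchPrefillWithPagedKVCacheWrapper in the real code
    use_ragged: bool
    extend_no_prefix: bool

# Before: self.forward_metadata = (use_ragged, extend_no_prefix)
# After:  fields are accessed by name, so decode and prefill state cannot be confused.
forward_metadata: Union[DecodeMetadata, PrefillMetadata]
forward_metadata = PrefillMetadata(prefill_wrappers=[], use_ragged=True, extend_no_prefix=False)
if isinstance(forward_metadata, PrefillMetadata) and not forward_metadata.use_ragged:
    pass  # paged-only prefill path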
sglang/srt/layers/attention/triton_ops/extend_attention.py CHANGED
@@ -292,27 +292,33 @@ def extend_attention_fwd(
  BLOCK_DPE = 0
  BLOCK_DV = triton.next_power_of_2(Lv)

- if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
- if Lq <= 256:
- BLOCK_M, BLOCK_N = (128, 64)
- else:
- BLOCK_M, BLOCK_N = (32, 64)
- elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
- if Lq <= 128:
- BLOCK_M, BLOCK_N = (128, 128)
- elif Lq <= 256:
- BLOCK_M, BLOCK_N = (64, 64)
- else:
- BLOCK_M, BLOCK_N = (32, 64)
+ if is_hip_:
+ BLOCK_M, BLOCK_N = (64, 64)
+ num_warps = 4
+
  else:
- BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+ if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+ if Lq <= 256:
+ BLOCK_M, BLOCK_N = (128, 64)
+ else:
+ BLOCK_M, BLOCK_N = (32, 64)
+ elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+ if Lq <= 128:
+ BLOCK_M, BLOCK_N = (128, 128)
+ elif Lq <= 256:
+ BLOCK_M, BLOCK_N = (64, 64)
+ else:
+ BLOCK_M, BLOCK_N = (32, 64)
+ else:
+ BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+ num_warps = 4 if Lk <= 64 else 8

  sm_scale = sm_scale or 1.0 / (Lq**0.5)
  batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
  kv_group_num = q_extend.shape[1] // k_extend.shape[1]

  grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
- num_warps = 4 if Lk <= 64 else 8
  num_stages = 1

  extra_kargs = {}
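The hunk above adds an AMD/HIP branch with a fixed tile configuration and moves the num_warps choice next to the block-size choice, keeping the existing CUDA-capability logic for the NVIDIA path. A standalone restatement of the selection (the real code sets these as locals inside extend_attention_fwd rather than in a helper):

def choose_launch_config(Lq: int, Lk: int, is_hip: bool, is_cuda_available: bool, cuda_capability: tuple):
    # HIP (ROCm) devices get a single fixed tile shape and warp count.
    if is_hip:
        return (64, 64), 4
    # CUDA devices pick tile sizes by compute capability and query head dim.
    if is_cuda_available and cuda_capability[0] >= 9:
        block_m, block_n = (128, 64) if Lq <= 256 else (32, 64)
    elif is_cuda_available and cuda_capability[0] >= 8:
        if Lq <= 128:
            block_m, block_n = (128, 128)
        elif Lq <= 256:
            block_m, block_n = (64, 64)
        else:
            block_m, block_n = (32, 64)
    else:
        block_m, block_n = (64, 64) if Lq <= 128 else (32, 32)
    num_warps = 4 if Lk <= 64 else 8
    return (block_m, block_n), num_warps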
sglang/srt/layers/logits_processor.py CHANGED
@@ -24,7 +24,11 @@ from vllm.distributed import (
  )

  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.model_executor.forward_batch_info import (
+ CaptureHiddenMode,
+ ForwardBatch,
+ ForwardMode,
+ )


  @dataclasses.dataclass
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
  output_top_logprobs_val: List = None
  output_top_logprobs_idx: List = None

+ # Used by speculative decoding (EAGLE)
+ # The output of transformer layers
+ hidden_states: Optional[torch.Tensor] = None
+

  @dataclasses.dataclass
  class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
  extend_logprob_start_lens_cpu: Optional[List[int]] = None
  extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

+ capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
  @classmethod
  def from_forward_batch(cls, forward_batch: ForwardBatch):
  extend_logprob_pruned_lens_cpu = None
@@ -78,6 +88,11 @@ class LogitsMetadata:
  else:
  return_top_logprob = False

+ if forward_batch.spec_info:
+ capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+ else:
+ capture_hidden_mode = CaptureHiddenMode.NULL
+
  return cls(
  forward_mode=forward_batch.forward_mode,
  top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
  extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
  extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
  extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+ capture_hidden_mode=capture_hidden_mode,
  )


@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
  assert isinstance(logits_metadata, LogitsMetadata)

  # Get the last hidden states and last logits for the next token prediction
- if logits_metadata.forward_mode.is_decode():
+ if (
+ logits_metadata.forward_mode.is_decode()
+ or logits_metadata.forward_mode.is_target_verify()
+ ):
  last_index = None
  last_hidden = hidden_states
  else:
@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
  if not logits_metadata.return_logprob:
  return LogitsProcessorOutput(
  next_token_logits=last_logits,
+ hidden_states=(
+ hidden_states
+ if logits_metadata.capture_hidden_mode.is_full()
+ else (
+ last_hidden
+ if logits_metadata.capture_hidden_mode.is_last()
+ else None
+ )
+ ),
  )
  else:
  last_logprobs = self.compute_temp_top_p_normalized_logprobs(
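The capture_hidden_mode plumbing above decides which hidden states, if any, ride along with the logits for EAGLE-style speculative decoding. A condensed sketch of that selection, assuming CaptureHiddenMode exposes is_full()/is_last() as used in the diff; the enum members other than NULL and the standalone helper are illustrative only (the real enum lives in forward_batch_info.py):

from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()
    FULL = auto()
    LAST = auto()

    def is_full(self) -> bool:
        return self == CaptureHiddenMode.FULL

    def is_last(self) -> bool:
        return self == CaptureHiddenMode.LAST

def select_hidden_states(mode: CaptureHiddenMode, hidden_states, last_hidden):
    # FULL keeps hidden states for every position, LAST keeps only the positions
    # used for next-token prediction, NULL attaches nothing.
    if mode.is_full():
        return hidden_states
    if mode.is_last():
        return last_hidden
    return None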