sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/io_struct.py
+++ b/sglang/srt/managers/io_struct.py
@@ -226,11 +226,11 @@ class GenerateReqInput:

         # Expand input based on type
         self._expand_inputs(num)
+        self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
-        self._normalize_rid(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)

@@ -319,8 +319,16 @@ class GenerateReqInput:
         """Normalize request IDs for batch processing."""
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
-        elif not isinstance(self.rid, list):
-            raise ValueError("The rid should be a list for batch processing.")
+        elif isinstance(self.rid, str):
+            new_rids = [f"{self.rid}_{i}" for i in range(num)]
+            self.rid = new_rids
+        elif isinstance(self.rid, list):
+            if len(self.rid) != num:
+                raise ValueError(
+                    "The specified rids length mismatch with the batch_size for batch processing."
+                )
+        else:
+            raise ValueError("The rid should be a string or a list of strings.")

     def _normalize_logprob_params(self, num):
         """Normalize logprob-related parameters for batch processing."""
@@ -530,6 +538,7 @@ class EmbeddingReqInput:
         if self.text is not None:
             if isinstance(self.text, list):
                 self.batch_size += len(self.text)
+                self.is_single = False
             else:
                 self.batch_size += 1

@@ -537,12 +546,10 @@ class EmbeddingReqInput:
         if self.input_ids is not None:
             if isinstance(self.input_ids[0], list):
                 self.batch_size += len(self.input_ids)
+                self.is_single = False
             else:
                 self.batch_size += 1

-        if self.batch_size > 1:
-            self.is_single = False
-
         # Fill in default arguments
         if self.is_single:
             if self.rid is None:
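With this change, is_single is derived from the input's shape rather than the final batch_size, so a one-element list is now treated as a batch. A quick illustration of the new rule (the helper name is hypothetical, and it assumes is_single defaults to True as the old batch_size > 1 test implies):

    def is_single_after_normalize(text):
        # Any list input is a batch, even a list of length 1.
        return not isinstance(text, list)

    print(is_single_after_normalize("hello"))    # True
    print(is_single_after_normalize(["hello"]))  # False; the old batch_size > 1 test kept this True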
@@ -812,7 +819,9 @@ class GetWeightsByNameReqOutput:

 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
@@ -822,7 +831,9 @@ class ReleaseMemoryOccupationReqOutput:

 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
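Both structs now let an RL trainer name which memory region to release or restore instead of toggling everything at once. A hedged construction sketch (the dataclass is re-declared locally so the snippet runs standalone; the usage scenario is an example, not taken from this diff):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class ReleaseMemoryOccupationReqInput:
        # Per the diff, only `weights` and `kv_cache` are supported tags.
        tags: Optional[List[str]] = None

    # e.g. drop only the KV cache around a weight sync, keeping weights resident:
    req = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
    print(req.tags)  # ['kv_cache']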
@@ -861,12 +872,6 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]


-@dataclass
-class V1RerankReqInput:
-    query: str
-    documents: List[str]
-
-
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
--- a/sglang/srt/managers/multimodal_processors/base_processor.py
+++ b/sglang/srt/managers/multimodal_processors/base_processor.py
@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
     RAW_IMAGES = "raw_images"
     PRECOMPUTED_FEATURES = "precomputed_features"
     PIXEL_VALUES = "pixel_values"
+    AUDIO = "audio"


 @dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
             has_image = False
             has_pixel_values = False
             has_precomputed_features = False
+            has_audio = False

             for mm_input in mm_inputs:
                 if isinstance(mm_input, Image.Image):
                     has_image = True
+                elif isinstance(mm_input, np.ndarray):
+                    has_audio = True
                 elif isinstance(mm_input, dict):
                     if mm_input.get("precomputed_features", None) is not None:
                         has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):

             # Validate format consistency
             format_count = sum(
-                [has_image, has_pixel_values, has_precomputed_features]
+                [has_image, has_pixel_values, has_precomputed_features, has_audio]
             )
             if format_count > 1:
                 raise ValueError(
                     "Unsupported: mixture of multimodal input formats. "
                     f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                    f"precomputed_features={has_precomputed_features}"
+                    f"precomputed_features={has_precomputed_features}, audio={has_audio}"
                 )

             if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
                 return MultimodalInputFormat.PRECOMPUTED_FEATURES
             elif has_pixel_values:
                 return MultimodalInputFormat.PIXEL_VALUES
+            elif has_audio:
+                return MultimodalInputFormat.AUDIO
             else:
                 raise ValueError("No valid multimodal input format found")
         except Exception as e:
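The categorization above treats any np.ndarray item as raw audio. A self-contained simplification of that detection rule (assuming, as the method does, that only one format may appear per request):

    import numpy as np
    from PIL import Image

    def categorize(mm_inputs):
        # Simplified mirror of the detection above: exactly one format allowed.
        has_image = any(isinstance(x, Image.Image) for x in mm_inputs)
        has_audio = any(isinstance(x, np.ndarray) for x in mm_inputs)
        if has_image and has_audio:
            raise ValueError("Unsupported: mixture of multimodal input formats.")
        if has_image:
            return "raw_images"
        if has_audio:
            return "audio"
        raise ValueError("No valid multimodal input format found")

    print(categorize([np.zeros(16000, dtype=np.float32)]))  # audio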
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
             input_ids = tokenize_text(base_output.input_text)
             return combined_mm_item, input_ids

+        def process_audio(
+            base_output: BaseMultiModalProcessorOutput,
+        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
+            """Process inputs with audio."""
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                audio=base_output.audios,  # Note: "audio" is for gemma3n only
+            )
+            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
+            for key, value in ret.items():
+                if key != "input_ids" and hasattr(combined_mm_item, key):
+                    setattr(combined_mm_item, key, value)
+            input_ids = ret["input_ids"].flatten()
+            return combined_mm_item, input_ids
+
         def finalize_mm_item(
             combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
         ) -> MultimodalDataItem:
             """Apply common post-processing to the multimodal item."""
-            combined_mm_item.image_offsets = self.get_mm_items_offset(
-                input_ids=input_ids,
-                mm_token_id=self.IM_TOKEN_ID,
-            )
+            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                combined_mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.AUDIO:
+                combined_mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif combined_mm_item.modality == Modality.VIDEO:
+                combined_mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
             return combined_mm_item

-        # Main logic
-        mm_inputs = base_output.images
+        # Main logic - determine input type and handle text-only case
+        mm_inputs = base_output.images or base_output.audios
         if not mm_inputs:
-            # Return text-only case
             input_ids = tokenize_text(base_output.input_text)
             return None, input_ids

@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
             combined_mm_item, input_ids = process_precomputed_features(base_output)
         elif input_format == MultimodalInputFormat.PIXEL_VALUES:
             combined_mm_item, input_ids = process_pixel_values(base_output)
+        elif input_format == MultimodalInputFormat.AUDIO:
+            combined_mm_item, input_ids = process_audio(base_output)
         else:
             raise ValueError(f"Unknown input format: {input_format}")

--- /dev/null
+++ b/sglang/srt/managers/multimodal_processors/gemma3n.py
@@ -0,0 +1,97 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import re
+from typing import Dict, List, Optional, Union
+
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
+
+
+class Gemma3nSGLangProcessor(SGLangBaseProcessor):
+    """Multimodal processor for Gemma3n supporting image and audio inputs."""
+
+    models = [Gemma3nForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+        self.IMAGE_TOKEN = "<image_soft_token>"
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+        )
+
+        self.AUDIO_TOKEN = "<audio_soft_token>"
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
+        )
+
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.IM_START_TOKEN_ID = hf_config.boi_token_id
+        self.IM_END_TOKEN_ID = hf_config.eoi_token_id
+
+        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
+        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
+        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
+        input_text: str = "",
+        request_obj=None,
+        max_req_input_len: int = 0,
+        *args,
+        **kwargs,
+    ):
+        """Process multimodal data including images and audio."""
+
+        audio_data = request_obj.audio_data
+        if not image_data and not audio_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        if isinstance(audio_data, str):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            audio_data=audio_data,
+            max_req_input_len=max_req_input_len,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN,
+                image_token_regex=self.IMAGE_TOKEN_REGEX,
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+
+        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "audio_start_id": self.AUDIO_START_TOKEN_ID,
+            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+        }
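The two token regexes in this file accept either a bare start marker or a fully expanded placeholder run. A runnable check of the audio pattern taken verbatim from the file above:

    import re

    AUDIO_TOKEN_REGEX = re.compile(
        r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
    )

    # Matches the bare start marker as well as an expanded placeholder run.
    print(bool(AUDIO_TOKEN_REGEX.search("hi <start_of_audio>")))  # True
    print(bool(AUDIO_TOKEN_REGEX.search(
        "<start_of_audio><audio_soft_token><audio_soft_token><end_of_audio>"
    )))  # True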
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
@@ -99,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
 ]

 # Put some global args for easy access
@@ -211,6 +214,10 @@ class MultimodalDataItem:
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None

+    # gemma3n related
+    input_features: Optional[torch.Tensor] = None
+    input_features_mask: Optional[torch.Tensor] = None
+
     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

     @staticmethod
@@ -274,7 +281,10 @@ class MultimodalDataItem:
         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)
         elif self.is_audio():
-            self.hash = hash_feature(self.audio_features)
+            if self.audio_features is not None:
+                self.hash = hash_feature(self.audio_features)
+            elif self.input_features is not None:
+                self.hash = hash_feature(self.input_features)
         else:
             self.hash = hash_feature(self.pixel_values)

@@ -285,6 +295,7 @@ class MultimodalDataItem:
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.audio_features)
+            or not MultimodalDataItem.is_empty_list(self.input_features)
         )

     def is_image(self):
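The hashing hunk above implies a priority order among feature sources. A simplified, runnable sketch of that order (the Item stand-in is hypothetical and drops the modality gating of the real method):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Item:  # hypothetical stand-in for MultimodalDataItem
        precomputed_features: Optional[Any] = None
        audio_features: Optional[Any] = None
        input_features: Optional[Any] = None
        pixel_values: Optional[Any] = None

    def pick_hash_source(item: Item):
        # Precomputed features win, then classic audio features, then the
        # gemma3n input_features, falling back to pixel values.
        if item.precomputed_features is not None:
            return item.precomputed_features
        if item.audio_features is not None:
            return item.audio_features
        if item.input_features is not None:
            return item.input_features
        return item.pixel_values

    print(pick_hash_source(Item(input_features="gemma3n-mel-spectrogram")))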
@@ -436,7 +447,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids: Tuple[int],
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
@@ -467,7 +478,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids = None
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds

@@ -519,13 +530,14 @@ class Req:

         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.last_node_global = None
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -583,6 +595,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

         # Embedding (return values)
         self.embedding = None
@@ -644,29 +657,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            # tree cache is None if the prefix is not computed with tree cache.
-            if enable_hierarchical_cache:
-                self.prefix_indices, self.last_node, self.last_node_global = (
-                    tree_cache.match_prefix(
-                        key=self.adjust_max_prefix_ids(), include_evicted=True
-                    )
-                )
-            else:
-                self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                    rid=self.rid, key=self.adjust_max_prefix_ids()
-                )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
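As the hunk shows, match_prefix now returns a single four-tuple whether or not hierarchical caching is enabled, folding the host-tier result into (prefix_indices, last_node, last_host_node, host_hit_length). A hedged sketch of the contract for a cache with no host (CPU) tier (class and values are hypothetical):

    from typing import Any, List, Tuple

    class FlatPrefixCache:
        """Hypothetical minimal cache honoring the new 4-tuple contract."""

        def match_prefix(self, key: List[int]) -> Tuple[List[int], Any, Any, int]:
            # No device match, no tree nodes, and zero host-hit length.
            return [], None, None, 0

    cache = FlatPrefixCache()
    prefix_indices, last_node, last_host_node, host_hit_length = cache.match_prefix(
        key=[1, 2, 3]
    )
    print(len(prefix_indices), host_hit_length)  # 0 0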
@@ -796,6 +797,7 @@ class Req:
         self.multimodal_inputs = None
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.return_logprob = False
         self.finished_reason = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )
@@ -820,7 +822,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -862,6 +864,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
@@ -908,12 +911,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
    @classmethod
     def init_new(
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -1365,7 +1371,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return len(self.reqs)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
-        return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        return (
+            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            if self.enable_overlap
+            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        )

     def check_decode_mem(self, buf_multiplier=1):
         tokens_required = (
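The overlap scheduler observes seqlen one step ahead, so the page-boundary test shifts by one. A worked check of when a decode step needs a fresh page (page_size = 4 is an arbitrary illustration):

    page_size = 4

    def needs_new_page(seqlen, enable_overlap):
        # Overlap mode tracks seqlen one token ahead, hence the shifted test.
        if enable_overlap:
            return seqlen % page_size == 0
        return (seqlen - 1) % page_size == 0

    # Non-overlap: seqlen=9 means 8 KV tokens fill two pages; token 9 opens page 3.
    print(needs_new_page(9, enable_overlap=False))  # True
    # Overlap: the same request is observed with seqlen=8.
    print(needs_new_page(8, enable_overlap=True))   # True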
@@ -1734,6 +1744,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states
@@ -1760,11 +1771,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )

     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )

@@ -1833,6 +1848,8 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0

     # Overlap event
     launch_done: Optional[threading.Event] = None