sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -18,7 +18,6 @@ from typing import Optional
 
 from torch import nn
 
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
sglang/srt/lora/lora_manager.py
@@ -81,7 +81,7 @@ class LoRAManager:
             seg_indptr=torch.zeros(
                 self.max_bs_in_cuda_graph + 1, dtype=torch.int32
             ),
-            max_len=0,
+            max_len=1,
             weight_indices=torch.zeros(
                 self.max_bs_in_cuda_graph, dtype=torch.int32
             ),
@@ -89,6 +89,17 @@ class LoRAManager:
             scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
         )
 
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
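
Note: the initialization above relies on decode batches contributing exactly one token per request, so both buffers can be filled once at CUDA-graph capture time. A standalone sketch of the idea (the max_bs value is illustrative, not taken from sglang):

import torch

# Illustrative CUDA-graph capture size; sglang uses max_bs_in_cuda_graph here.
max_bs = 4

seg_lens = torch.zeros(max_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)

# Every decode request contributes one token, so segment lengths are all 1
# and the prefix sum is simply 0, 1, 2, ..., max_bs.
seg_lens.fill_(1)
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])

print(seg_lens.tolist())    # [1, 1, 1, 1]
print(seg_indptr.tolist())  # [0, 1, 2, 3, 4]
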
@@ -159,6 +170,45 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
@@ -166,51 +216,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
             )
-            self.cuda_graph_batch_info.max_len = 1
 
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
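
Note: the new transfer_adapter_info path stages adapter metadata in pinned host memory and then issues non-blocking copies, so both the CUDA-graph branch and the regular branch avoid stalling on a host-device synchronization. A minimal standalone sketch of that pattern (tensor names and sizes are illustrative, not sglang code):

import torch

if torch.cuda.is_available():
    # Host-side staging buffer in pinned (page-locked) memory.
    host_vals = torch.tensor([0, 2, 1, 0], dtype=torch.int32, pin_memory=True)

    # Pre-allocated device destination, e.g. a CUDA-graph input buffer.
    device_buf = torch.zeros(8, dtype=torch.int32, device="cuda")

    # non_blocking=True only overlaps with compute when the source is pinned;
    # with pageable memory PyTorch falls back to a synchronous copy.
    device_buf[: host_vals.numel()].copy_(host_vals, non_blocking=True)
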
sglang/srt/lora/mem_pool.py
@@ -132,12 +132,13 @@ class LoRAMemoryPool:
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,9 +146,7 @@ class LoRAMemoryPool:
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
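
Note: after this refactor the slot-selection policy is self-contained: prefer an empty slot, otherwise evict an adapter the current batch does not need and drop its uid mapping in the same step. A small standalone model of that policy (plain list/dict stand-ins, not the sglang pool):

def get_available_buffer_slot(buffer_id_to_uid, uid_to_buffer_id, cur_uids):
    # Prioritize empty slots.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return buffer_id

    # Otherwise evict a loaded adapter that the current batch does not need.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            uid_to_buffer_id.pop(uid)
            return buffer_id

    raise ValueError("No available buffer slots found.")
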
sglang/srt/managers/cache_controller.py
@@ -22,7 +22,8 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 logger = logging.getLogger(__name__)
 
sglang/srt/managers/io_struct.py
@@ -87,7 +87,7 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # LoRA related
+    # The path to the LoRA
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -477,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -501,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -560,6 +566,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -579,6 +595,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
@@ -843,6 +861,12 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
 
+@dataclass
+class V1RerankReqInput:
+    query: str
+    documents: List[str]
+
+
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
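
Note: V1RerankReqInput is a plain dataclass, so a rerank payload maps directly onto its two fields; a minimal illustration (field values are made up, and the endpoint wiring lives in the OpenAI-compatible adapter changed elsewhere in this release):

from dataclasses import asdict

from sglang.srt.managers.io_struct import V1RerankReqInput

req = V1RerankReqInput(
    query="what is sglang?",
    documents=[
        "SGLang is a fast serving framework for large language models.",
        "A wheel is a built-package format for Python.",
    ],
)

print(asdict(req))  # {'query': ..., 'documents': [...]}
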
sglang/srt/managers/multimodal_processors/base_processor.py
@@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC):
         request_obj,
         max_req_input_len,
         **kwargs,
-    ):
+    ) -> Optional[Dict[str, Any]]:
         pass
 
     def get_estimated_frames_list(self, image_data):
@@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC):
 
     def load_mm_data(
         self,
-        prompt: str,
+        prompt: str | List[int],
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
sglang/srt/managers/multimodal_processors/vila.py (new file)
@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Optional, Type, cast
+
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    ImageDataItem,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.server_args import ServerArgs
+
+
+class VILAProcessor(ProcessorMixin):
+    """A stub class for the VILA processor."""
+
+    tokenizer: PreTrainedTokenizerBase
+
+
+class VILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]
+
+    _processor: VILAProcessor
+
+    def __init__(
+        self,
+        hf_config: PretrainedConfig,
+        server_args: ServerArgs,
+        _processor: VILAProcessor,
+    ) -> None:
+        super().__init__(hf_config, server_args, _processor)
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[ImageDataItem | List[ImageDataItem]],
+        input_text: str | List[int],
+        request_obj: GenerateReqInput | EmbeddingReqInput,
+        max_req_input_len: int,
+        **kwargs,
+    ) -> Optional[Dict[str, Any]]:
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self._processor.tokenizer.image_token
+            ),
+            max_req_input_len=max_req_input_len,
+            image_data=image_data,
+        )
+
+        inputs = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        image_offsets = self.get_mm_items_offset(
+            input_ids=inputs.input_ids[0],
+            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
+        )
+
+        mm_items: List[MultimodalDataItem] = [
+            MultimodalDataItem(
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+                pixel_values=inputs.pixel_values,
+            )
+        ]
+
+        return dict(
+            input_ids=inputs.input_ids[0].tolist(),
+            mm_items=mm_items,
+        )
sglang/srt/managers/schedule_batch.py
@@ -72,32 +72,33 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 GLOBAL_SERVER_ARGS_KEYS = [
     "attention_backend",
+    "mm_attention_backend",
     "debug_tensor_dump_inject",
     "debug_tensor_dump_output_folder",
     "chunked_prefill_size",
-    "deepep_mode",
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_deepep_moe",
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
+    "enable_deepep_moe",
+    "deepep_mode",
     "enable_ep_moe",
+    "moe_dense_tp_size",
+    "ep_dispatch_algorithm",
     "deepep_config",
+    "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
     "max_micro_batch_size",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
-    "ep_num_redundant_experts",
-    "mm_attention_backend",
+    "num_reserved_decode_tokens",
 ]
 
 # Put some global args for easy access
@@ -444,6 +445,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -469,6 +471,9 @@ class Req:
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # for corss-endoder model
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -840,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -1141,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1153,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
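
Note: sum(token_type_ids, []) concatenates the per-request token-type lists in request order, matching the concatenated layout of the batch's input_ids before a single host-to-device copy. A standalone illustration (values are made up):

import torch

# Per-request token type ids, e.g. 0 = query tokens and 1 = document tokens
# for a cross-encoder; lengths may differ between requests.
token_type_ids = [
    [0, 0, 0, 1, 1],
    [0, 0, 1, 1, 1, 1],
]

flat = sum(token_type_ids, [])  # [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1]
token_type_ids_tensor = torch.tensor(flat, dtype=torch.int64)
print(token_type_ids_tensor.shape)  # torch.Size([11])
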
@@ -1268,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1414,6 +1432,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1445,6 +1468,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
@@ -1702,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
@@ -1795,6 +1825,9 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For corss-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None