sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import dataclasses
2
2
  import logging
3
- from typing import Dict, List, Optional, Sequence
3
+ from dataclasses import replace
4
+ from typing import Dict, List, Optional, Sequence, Union
4
5
 
5
6
  import torch
6
7
 
@@ -12,10 +13,11 @@ from sglang.srt.layers.communicator import (
12
13
  )
13
14
  from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
14
15
  from sglang.srt.layers.quantization import deep_gemm_wrapper
15
- from sglang.srt.managers.schedule_batch import global_server_args_dict
16
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
16
17
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
17
18
  from sglang.srt.operations import execute_operations, execute_overlapped_operations
18
19
  from sglang.srt.operations_strategy import OperationsStrategy
20
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
19
21
  from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
20
22
 
21
23
  _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
26
28
  # -------------------------------- Compute Basic Info ---------------------------------------
27
29
 
28
30
 
31
+ def get_token_num_per_seq(
32
+ forward_mode: ForwardMode,
33
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
34
+ ):
35
+ if forward_mode.is_target_verify():
36
+ return spec_info.draft_token_num
37
+ elif forward_mode.is_decode():
38
+ return 1
39
+ elif forward_mode.is_idle():
40
+ return 0
41
+ else:
42
+ # For extend, we should not use `token_num_per_seq`.
43
+ return None
44
+
45
+
29
46
  # TODO: may smartly disable TBO when batch size is too small b/c it will slow down
30
47
  def compute_split_seq_index(
31
48
  forward_mode: "ForwardMode",
32
49
  num_tokens: int,
33
50
  extend_lens: Optional[Sequence[int]],
51
+ token_num_per_seq: Optional[int],
34
52
  ) -> Optional[int]:
35
- if forward_mode.is_extend():
53
+ if forward_mode == ForwardMode.EXTEND:
36
54
  assert extend_lens is not None
37
55
  return _split_array_by_half_sum(extend_lens)
38
- elif forward_mode.is_decode():
39
- return num_tokens // 2
56
+ elif forward_mode.is_target_verify() or forward_mode.is_decode():
57
+ assert token_num_per_seq is not None
58
+ return (num_tokens // token_num_per_seq) // 2
40
59
  elif forward_mode.is_idle():
41
60
  assert num_tokens == 0
42
61
  return 0
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
63
82
  return best_index
64
83
 
65
84
 
85
+ def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
86
+ if seq_index == 0:
87
+ return 0
88
+
89
+ offset = 0
90
+ max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
91
+ for i in range(max_seq_len):
92
+ offset += (
93
+ spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
94
+ ) * spec_info.draft_token_num
95
+ return offset
96
+
97
+
98
+ def split_spec_info(
99
+ spec_info: Optional[EagleVerifyInput],
100
+ start_seq_index: int,
101
+ end_seq_index: int,
102
+ start_token_index: int,
103
+ end_token_index: int,
104
+ ):
105
+ if spec_info is None:
106
+ return None
107
+ if spec_info.draft_token is not None:
108
+ draft_token = spec_info.draft_token[start_token_index:end_token_index]
109
+ else:
110
+ draft_token = None
111
+ if spec_info.custom_mask is not None and spec_info.draft_token is not None:
112
+ custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
113
+ if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
114
+ custom_mask_end = spec_info.custom_mask.shape[0]
115
+ else:
116
+ custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
117
+
118
+ if custom_mask_end > custom_mask_start:
119
+ custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
120
+ else:
121
+ custom_mask = spec_info.custom_mask
122
+ else:
123
+ custom_mask = spec_info.custom_mask
124
+ if spec_info.positions is not None:
125
+ positions = spec_info.positions[start_token_index:end_token_index]
126
+ else:
127
+ positions = None
128
+ if spec_info.retrive_index is not None:
129
+ retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
130
+ else:
131
+ retrive_index = None
132
+ if spec_info.retrive_next_token is not None:
133
+ retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
134
+ else:
135
+ retrive_next_token = None
136
+ if spec_info.retrive_next_sibling is not None:
137
+ retrive_next_sibling = spec_info.retrive_next_sibling[
138
+ start_seq_index:end_seq_index
139
+ ]
140
+ else:
141
+ retrive_next_sibling = None
142
+ if spec_info.retrive_cum_len is not None:
143
+ retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
144
+ else:
145
+ retrive_cum_len = None
146
+
147
+ if spec_info.seq_lens_cpu is not None:
148
+ seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
149
+ else:
150
+ seq_lens_cpu = None
151
+ if seq_lens_cpu is not None:
152
+ seq_lens_sum = seq_lens_cpu.sum()
153
+ else:
154
+ seq_lens_sum = None
155
+ output_spec_info = replace(
156
+ spec_info,
157
+ custom_mask=custom_mask,
158
+ draft_token=draft_token,
159
+ positions=positions,
160
+ retrive_index=retrive_index,
161
+ retrive_next_token=retrive_next_token,
162
+ retrive_next_sibling=retrive_next_sibling,
163
+ retrive_cum_len=retrive_cum_len,
164
+ seq_lens_cpu=seq_lens_cpu,
165
+ seq_lens_sum=seq_lens_sum,
166
+ )
167
+ return output_spec_info
168
+
169
+
66
170
  def compute_split_token_index(
67
171
  split_seq_index: int,
68
172
  forward_mode: "ForwardMode",
69
173
  extend_seq_lens: Optional[Sequence[int]],
174
+ token_num_per_seq: Optional[int],
70
175
  ) -> int:
71
- if forward_mode.is_extend():
176
+ if forward_mode == ForwardMode.EXTEND:
72
177
  assert extend_seq_lens is not None
73
178
  return sum(extend_seq_lens[:split_seq_index])
74
- elif forward_mode.is_decode():
75
- return split_seq_index
179
+ elif forward_mode.is_target_verify() or forward_mode.is_decode():
180
+ assert token_num_per_seq is not None
181
+ return split_seq_index * token_num_per_seq
76
182
  elif forward_mode.is_idle():
77
183
  assert split_seq_index == 0
78
184
  return 0
@@ -83,19 +189,25 @@ def compute_split_token_index(
83
189
  def compute_split_indices_for_cuda_graph_replay(
84
190
  forward_mode: ForwardMode,
85
191
  cuda_graph_num_tokens: int,
192
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
86
193
  ):
87
194
  forward_mode_for_tbo_split = (
88
195
  forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
89
196
  )
197
+ token_num_per_seq = get_token_num_per_seq(
198
+ forward_mode=forward_mode, spec_info=spec_info
199
+ )
90
200
  tbo_split_seq_index = compute_split_seq_index(
91
201
  forward_mode=forward_mode_for_tbo_split,
92
202
  num_tokens=cuda_graph_num_tokens,
93
203
  extend_lens=None,
204
+ token_num_per_seq=token_num_per_seq,
94
205
  )
95
206
  tbo_split_token_index = compute_split_token_index(
96
207
  split_seq_index=tbo_split_seq_index,
97
208
  forward_mode=forward_mode_for_tbo_split,
98
209
  extend_seq_lens=None,
210
+ token_num_per_seq=token_num_per_seq,
99
211
  )
100
212
  return tbo_split_seq_index, tbo_split_token_index
101
213
 
@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
110
222
  def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
111
223
  if not global_server_args_dict["enable_two_batch_overlap"]:
112
224
  return
225
+ token_num_per_seq = get_token_num_per_seq(
226
+ forward_mode=batch.forward_mode, spec_info=batch.spec_info
227
+ )
113
228
 
114
229
  batch.tbo_split_seq_index = compute_split_seq_index(
115
230
  forward_mode=batch.forward_mode,
116
231
  num_tokens=num_tokens,
117
232
  extend_lens=None,
233
+ token_num_per_seq=token_num_per_seq,
118
234
  )
119
235
  # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
120
236
  assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
129
245
  )
130
246
 
131
247
  def replay_prepare(
132
- self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
248
+ self,
249
+ forward_mode: ForwardMode,
250
+ bs: int,
251
+ num_token_non_padded: int,
252
+ spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
133
253
  ):
254
+ token_num_per_seq = get_token_num_per_seq(
255
+ forward_mode=forward_mode, spec_info=spec_info
256
+ )
134
257
  tbo_split_seq_index, tbo_split_token_index = (
135
258
  compute_split_indices_for_cuda_graph_replay(
136
259
  forward_mode=forward_mode,
137
- # TODO support bs!=num_tokens
138
- cuda_graph_num_tokens=bs,
260
+ cuda_graph_num_tokens=bs * token_num_per_seq,
261
+ spec_info=spec_info,
139
262
  )
140
263
  )
141
264
 
@@ -149,19 +272,38 @@ class TboCudaGraphRunnerPlugin:
149
272
 
150
273
  class TboDPAttentionPreparer:
151
274
  def prepare_all_gather(
152
- self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
275
+ self,
276
+ local_batch: ScheduleBatch,
277
+ deepep_mode: DeepEPMode,
278
+ enable_deepep_moe: bool,
279
+ enable_two_batch_overlap: bool,
153
280
  ):
154
281
  self.enable_two_batch_overlap = enable_two_batch_overlap
155
282
 
156
283
  if local_batch is not None:
284
+ token_num_per_seq = get_token_num_per_seq(
285
+ forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
286
+ )
287
+
288
+ if (
289
+ local_batch.forward_mode.is_target_verify()
290
+ or local_batch.forward_mode.is_decode()
291
+ ):
292
+ num_tokens = local_batch.batch_size() * token_num_per_seq
293
+ else:
294
+ num_tokens = local_batch.extend_num_tokens
157
295
  self.local_tbo_split_seq_index = compute_split_seq_index(
158
296
  forward_mode=local_batch.forward_mode,
159
- num_tokens=local_batch.input_ids.shape[0],
297
+ num_tokens=num_tokens,
160
298
  extend_lens=local_batch.extend_lens,
299
+ token_num_per_seq=token_num_per_seq,
161
300
  )
162
- resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
301
+ resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
163
302
  local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
164
- local_batch.forward_mode.is_extend()
303
+ (
304
+ local_batch.forward_mode.is_extend()
305
+ and not local_batch.forward_mode.is_target_verify()
306
+ )
165
307
  and enable_deepep_moe
166
308
  and (resolved_deepep_mode == DeepEPMode.low_latency)
167
309
  )
@@ -218,8 +360,8 @@ class TboDPAttentionPreparer:
218
360
 
219
361
  class TboForwardBatchPreparer:
220
362
  @classmethod
221
- def prepare(cls, batch: ForwardBatch):
222
- if batch.tbo_split_seq_index is None:
363
+ def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
364
+ if batch.tbo_split_seq_index is None or is_draft_worker:
223
365
  return
224
366
 
225
367
  tbo_children_num_token_non_padded = (
@@ -242,7 +384,9 @@ class TboForwardBatchPreparer:
242
384
  f"TboForwardBatchPreparer.prepare "
243
385
  f"tbo_split_seq_index={batch.tbo_split_seq_index} "
244
386
  f"tbo_split_token_index={tbo_split_token_index} "
245
- f"extend_seq_lens={batch.extend_seq_lens_cpu}"
387
+ f"extend_seq_lens={batch.extend_seq_lens_cpu} "
388
+ f"bs={batch.batch_size} "
389
+ f"forward_mode={batch.forward_mode}"
246
390
  )
247
391
 
248
392
  assert isinstance(batch.attn_backend, TboAttnBackend)
@@ -286,6 +430,9 @@ class TboForwardBatchPreparer:
286
430
  output_attn_backend: AttentionBackend,
287
431
  out_num_token_non_padded: torch.Tensor,
288
432
  ):
433
+ assert (
434
+ end_token_index >= start_token_index
435
+ ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
289
436
  num_tokens = batch.input_ids.shape[0]
290
437
  num_seqs = batch.batch_size
291
438
 
@@ -317,11 +464,30 @@ class TboForwardBatchPreparer:
317
464
  old_value = getattr(batch, key)
318
465
  if old_value is None:
319
466
  continue
467
+ elif batch.forward_mode.is_target_verify() and (
468
+ key == "extend_seq_lens"
469
+ or key == "extend_prefix_lens"
470
+ or key == "extend_start_loc"
471
+ or key == "extend_prefix_lens_cpu"
472
+ or key == "extend_seq_lens_cpu"
473
+ or key == "extend_logprob_start_lens_cpu"
474
+ ):
475
+ output_dict[key] = None
476
+ continue
320
477
  assert (
321
478
  len(old_value) == num_seqs
322
479
  ), f"{key=} {old_value=} {num_seqs=} {batch=}"
323
480
  output_dict[key] = old_value[start_seq_index:end_seq_index]
324
481
 
482
+ spec_info = getattr(batch, "spec_info")
483
+ output_spec_info = split_spec_info(
484
+ spec_info=spec_info,
485
+ start_token_index=start_token_index,
486
+ end_token_index=end_token_index,
487
+ start_seq_index=start_seq_index,
488
+ end_seq_index=end_seq_index,
489
+ )
490
+ output_dict["spec_info"] = output_spec_info
325
491
  for key in [
326
492
  "forward_mode",
327
493
  "return_logprob",
@@ -329,24 +495,26 @@ class TboForwardBatchPreparer:
329
495
  "token_to_kv_pool",
330
496
  "can_run_dp_cuda_graph",
331
497
  "global_forward_mode",
332
- "spec_info",
333
498
  "spec_algorithm",
334
499
  "capture_hidden_mode",
335
500
  "padded_static_len",
336
501
  "mrope_positions", # only used by qwen2-vl, thus not care
337
502
  ]:
338
503
  output_dict[key] = getattr(batch, key)
339
-
340
- assert (
341
- _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
342
- == batch.extend_num_tokens
343
- ), f"{batch=}"
504
+ if not batch.forward_mode.is_target_verify():
505
+ assert (
506
+ _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
507
+ == batch.extend_num_tokens
508
+ ), f"{batch=}"
344
509
  extend_num_tokens = _compute_extend_num_tokens(
345
510
  output_dict["input_ids"], output_dict["forward_mode"]
346
511
  )
347
512
 
348
513
  # TODO improve, e.g. unify w/ `init_raw`
349
- if global_server_args_dict["moe_dense_tp_size"] == 1:
514
+ if (
515
+ global_server_args_dict["moe_dense_tp_size"] == 1
516
+ and batch.gathered_buffer is not None
517
+ ):
350
518
  sum_len = end_token_index - start_token_index
351
519
  gathered_buffer = torch.zeros(
352
520
  (sum_len, batch.gathered_buffer.shape[1]),
@@ -416,18 +584,26 @@ class TboForwardBatchPreparer:
416
584
 
417
585
  @classmethod
418
586
  def _compute_split_token_index(cls, batch: ForwardBatch):
587
+ token_num_per_seq = get_token_num_per_seq(
588
+ forward_mode=batch.forward_mode, spec_info=batch.spec_info
589
+ )
419
590
  return compute_split_token_index(
420
591
  split_seq_index=batch.tbo_split_seq_index,
421
592
  forward_mode=batch.forward_mode,
422
593
  extend_seq_lens=batch.extend_seq_lens_cpu,
594
+ token_num_per_seq=token_num_per_seq,
423
595
  )
424
596
 
425
597
 
426
598
  def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
427
- if forward_mode.is_extend():
428
- return input_ids.shape[0]
429
- elif forward_mode.is_decode() or forward_mode.is_idle():
599
+ if (
600
+ forward_mode.is_decode()
601
+ or forward_mode.is_idle()
602
+ or forward_mode.is_target_verify()
603
+ ):
430
604
  return None
605
+ elif forward_mode.is_extend():
606
+ return input_ids.shape[0]
431
607
  raise NotImplementedError
432
608
 
433
609