sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py
CHANGED
@@ -1,6 +1,7 @@
 import dataclasses
 import logging
-from
+from dataclasses import replace
+from typing import Dict, List, Optional, Sequence, Union

 import torch

@@ -12,10 +13,11 @@ from sglang.srt.layers.communicator import (
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var

 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
 # -------------------------------- Compute Basic Info ---------------------------------------


+def get_token_num_per_seq(
+    forward_mode: ForwardMode,
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+):
+    if forward_mode.is_target_verify():
+        return spec_info.draft_token_num
+    elif forward_mode.is_decode():
+        return 1
+    elif forward_mode.is_idle():
+        return 0
+    else:
+        # For extend, we should not use `token_num_per_seq`.
+        return None
+
+
 # TODO: may smartly disable TBO when batch size is too small b/c it will slow down
 def compute_split_seq_index(
     forward_mode: "ForwardMode",
     num_tokens: int,
     extend_lens: Optional[Sequence[int]],
+    token_num_per_seq: Optional[int],
 ) -> Optional[int]:
-    if forward_mode.
+    if forward_mode == ForwardMode.EXTEND:
         assert extend_lens is not None
         return _split_array_by_half_sum(extend_lens)
-    elif forward_mode.is_decode():
-
+    elif forward_mode.is_target_verify() or forward_mode.is_decode():
+        assert token_num_per_seq is not None
+        return (num_tokens // token_num_per_seq) // 2
     elif forward_mode.is_idle():
         assert num_tokens == 0
         return 0
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     return best_index


+def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
+    if seq_index == 0:
+        return 0
+
+    offset = 0
+    max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
+    for i in range(max_seq_len):
+        offset += (
+            spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
+        ) * spec_info.draft_token_num
+    return offset
+
+
+def split_spec_info(
+    spec_info: Optional[EagleVerifyInput],
+    start_seq_index: int,
+    end_seq_index: int,
+    start_token_index: int,
+    end_token_index: int,
+):
+    if spec_info is None:
+        return None
+    if spec_info.draft_token is not None:
+        draft_token = spec_info.draft_token[start_token_index:end_token_index]
+    else:
+        draft_token = None
+    if spec_info.custom_mask is not None and spec_info.draft_token is not None:
+        custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
+        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
+            custom_mask_end = spec_info.custom_mask.shape[0]
+        else:
+            custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
+
+        if custom_mask_end > custom_mask_start:
+            custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
+        else:
+            custom_mask = spec_info.custom_mask
+    else:
+        custom_mask = spec_info.custom_mask
+    if spec_info.positions is not None:
+        positions = spec_info.positions[start_token_index:end_token_index]
+    else:
+        positions = None
+    if spec_info.retrive_index is not None:
+        retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
+    else:
+        retrive_index = None
+    if spec_info.retrive_next_token is not None:
+        retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
+    else:
+        retrive_next_token = None
+    if spec_info.retrive_next_sibling is not None:
+        retrive_next_sibling = spec_info.retrive_next_sibling[
+            start_seq_index:end_seq_index
+        ]
+    else:
+        retrive_next_sibling = None
+    if spec_info.retrive_cum_len is not None:
+        retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
+    else:
+        retrive_cum_len = None
+
+    if spec_info.seq_lens_cpu is not None:
+        seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
+    else:
+        seq_lens_cpu = None
+    if seq_lens_cpu is not None:
+        seq_lens_sum = seq_lens_cpu.sum()
+    else:
+        seq_lens_sum = None
+    output_spec_info = replace(
+        spec_info,
+        custom_mask=custom_mask,
+        draft_token=draft_token,
+        positions=positions,
+        retrive_index=retrive_index,
+        retrive_next_token=retrive_next_token,
+        retrive_next_sibling=retrive_next_sibling,
+        retrive_cum_len=retrive_cum_len,
+        seq_lens_cpu=seq_lens_cpu,
+        seq_lens_sum=seq_lens_sum,
+    )
+    return output_spec_info
+
+
 def compute_split_token_index(
     split_seq_index: int,
     forward_mode: "ForwardMode",
     extend_seq_lens: Optional[Sequence[int]],
+    token_num_per_seq: Optional[int],
 ) -> int:
-    if forward_mode.
+    if forward_mode == ForwardMode.EXTEND:
         assert extend_seq_lens is not None
         return sum(extend_seq_lens[:split_seq_index])
-    elif forward_mode.is_decode():
-
+    elif forward_mode.is_target_verify() or forward_mode.is_decode():
+        assert token_num_per_seq is not None
+        return split_seq_index * token_num_per_seq
     elif forward_mode.is_idle():
         assert split_seq_index == 0
         return 0
@@ -83,19 +189,25 @@ compute_split_token_index
 def compute_split_indices_for_cuda_graph_replay(
     forward_mode: ForwardMode,
     cuda_graph_num_tokens: int,
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
 ):
     forward_mode_for_tbo_split = (
         forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
     )
+    token_num_per_seq = get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     tbo_split_seq_index = compute_split_seq_index(
         forward_mode=forward_mode_for_tbo_split,
         num_tokens=cuda_graph_num_tokens,
         extend_lens=None,
+        token_num_per_seq=token_num_per_seq,
     )
     tbo_split_token_index = compute_split_token_index(
         split_seq_index=tbo_split_seq_index,
         forward_mode=forward_mode_for_tbo_split,
         extend_seq_lens=None,
+        token_num_per_seq=token_num_per_seq,
     )
     return tbo_split_seq_index, tbo_split_token_index

@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
         if not global_server_args_dict["enable_two_batch_overlap"]:
             return
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=batch.forward_mode, spec_info=batch.spec_info
+        )

         batch.tbo_split_seq_index = compute_split_seq_index(
             forward_mode=batch.forward_mode,
             num_tokens=num_tokens,
             extend_lens=None,
+            token_num_per_seq=token_num_per_seq,
         )
         # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
         assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
         )

     def replay_prepare(
-        self,
+        self,
+        forward_mode: ForwardMode,
+        bs: int,
+        num_token_non_padded: int,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         tbo_split_seq_index, tbo_split_token_index = (
             compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
-
-
+                cuda_graph_num_tokens=bs * token_num_per_seq,
+                spec_info=spec_info,
             )
         )

@@ -149,19 +272,38 @@ class TboCudaGraphRunnerPlugin:

 class TboDPAttentionPreparer:
     def prepare_all_gather(
-        self,
+        self,
+        local_batch: ScheduleBatch,
+        deepep_mode: DeepEPMode,
+        enable_deepep_moe: bool,
+        enable_two_batch_overlap: bool,
     ):
         self.enable_two_batch_overlap = enable_two_batch_overlap

         if local_batch is not None:
+            token_num_per_seq = get_token_num_per_seq(
+                forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
+            )
+
+            if (
+                local_batch.forward_mode.is_target_verify()
+                or local_batch.forward_mode.is_decode()
+            ):
+                num_tokens = local_batch.batch_size() * token_num_per_seq
+            else:
+                num_tokens = local_batch.extend_num_tokens
             self.local_tbo_split_seq_index = compute_split_seq_index(
                 forward_mode=local_batch.forward_mode,
-                num_tokens=
+                num_tokens=num_tokens,
                 extend_lens=local_batch.extend_lens,
+                token_num_per_seq=token_num_per_seq,
             )
-            resolved_deepep_mode = deepep_mode.resolve(local_batch.
+            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
             local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
-
+                (
+                    local_batch.forward_mode.is_extend()
+                    and not local_batch.forward_mode.is_target_verify()
+                )
                 and enable_deepep_moe
                 and (resolved_deepep_mode == DeepEPMode.low_latency)
             )
@@ -218,8 +360,8 @@ class TboDPAttentionPreparer:

 class TboForwardBatchPreparer:
     @classmethod
-    def prepare(cls, batch: ForwardBatch):
-        if batch.tbo_split_seq_index is None:
+    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
+        if batch.tbo_split_seq_index is None or is_draft_worker:
             return

         tbo_children_num_token_non_padded = (
@@ -242,7 +384,9 @@ class TboForwardBatchPreparer:
                 f"TboForwardBatchPreparer.prepare "
                 f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                 f"tbo_split_token_index={tbo_split_token_index} "
-                f"extend_seq_lens={batch.extend_seq_lens_cpu}"
+                f"extend_seq_lens={batch.extend_seq_lens_cpu} "
+                f"bs={batch.batch_size} "
+                f"forward_mode={batch.forward_mode}"
             )

         assert isinstance(batch.attn_backend, TboAttnBackend)
@@ -286,6 +430,9 @@ class TboForwardBatchPreparer:
         output_attn_backend: AttentionBackend,
         out_num_token_non_padded: torch.Tensor,
     ):
+        assert (
+            end_token_index >= start_token_index
+        ), f"{end_token_index=}, {start_token_index=}, batch={batch}"
         num_tokens = batch.input_ids.shape[0]
         num_seqs = batch.batch_size

@@ -317,30 +464,49 @@ class TboForwardBatchPreparer:
             old_value = getattr(batch, key)
             if old_value is None:
                 continue
+            elif batch.forward_mode.is_target_verify() and (
+                key == "extend_seq_lens"
+                or key == "extend_prefix_lens"
+                or key == "extend_start_loc"
+                or key == "extend_prefix_lens_cpu"
+                or key == "extend_seq_lens_cpu"
+                or key == "extend_logprob_start_lens_cpu"
+            ):
+                output_dict[key] = None
+                continue
             assert (
                 len(old_value) == num_seqs
             ), f"{key=} {old_value=} {num_seqs=} {batch=}"
             output_dict[key] = old_value[start_seq_index:end_seq_index]

+        spec_info = getattr(batch, "spec_info")
+        output_spec_info = split_spec_info(
+            spec_info=spec_info,
+            start_token_index=start_token_index,
+            end_token_index=end_token_index,
+            start_seq_index=start_seq_index,
+            end_seq_index=end_seq_index,
+        )
+        output_dict["spec_info"] = output_spec_info
         for key in [
             "forward_mode",
+            "is_extend_in_batch",
             "return_logprob",
             "req_to_token_pool",
             "token_to_kv_pool",
             "can_run_dp_cuda_graph",
             "global_forward_mode",
-            "spec_info",
             "spec_algorithm",
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
-
-
-
-
-
+        if not batch.forward_mode.is_target_verify():
+            assert (
+                _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
+                == batch.extend_num_tokens
+            ), f"{batch=}"
         extend_num_tokens = _compute_extend_num_tokens(
             output_dict["input_ids"], output_dict["forward_mode"]
         )
@@ -385,6 +551,8 @@ class TboForwardBatchPreparer:
                 top_p_normalized_logprobs=False,
                 top_p=None,
                 mm_inputs=None,
+                top_logprobs_nums=None,
+                token_ids_logprobs=None,
             )
         )

@@ -419,18 +587,26 @@ class TboForwardBatchPreparer:

     @classmethod
     def _compute_split_token_index(cls, batch: ForwardBatch):
+        token_num_per_seq = get_token_num_per_seq(
+            forward_mode=batch.forward_mode, spec_info=batch.spec_info
+        )
         return compute_split_token_index(
             split_seq_index=batch.tbo_split_seq_index,
             forward_mode=batch.forward_mode,
             extend_seq_lens=batch.extend_seq_lens_cpu,
+            token_num_per_seq=token_num_per_seq,
         )


 def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
-    if
-
-
+    if (
+        forward_mode.is_decode()
+        or forward_mode.is_idle()
+        or forward_mode.is_target_verify()
+    ):
         return None
+    elif forward_mode.is_extend():
+        return input_ids.shape[0]
     raise NotImplementedError

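For readers skimming the diff above: the key behavioral change is that two-batch-overlap splitting is now driven by a per-sequence token count instead of assuming one token per sequence, so speculative target-verify batches (which carry `draft_token_num` tokens per sequence) can also be split. The snippet below is a minimal, self-contained sketch of that arithmetic only; the plain `mode` strings and the standalone helper names are hypothetical stand-ins for SGLang's `ForwardMode` and spec-info objects, not part of its API.

```python
from typing import Optional, Tuple


def token_num_per_seq(mode: str, draft_token_num: Optional[int] = None) -> Optional[int]:
    # Mirrors get_token_num_per_seq above: target-verify batches carry
    # draft_token_num tokens per sequence, decode carries 1, idle carries 0,
    # and extend batches do not use a per-sequence token count at all.
    if mode == "target_verify":
        return draft_token_num
    if mode == "decode":
        return 1
    if mode == "idle":
        return 0
    return None  # extend


def split_indices(num_tokens: int, per_seq: int) -> Tuple[int, int]:
    # Mirrors the decode/target-verify branches of compute_split_seq_index and
    # compute_split_token_index: split the sequences roughly in half, then map
    # the sequence split point back to a token index.
    split_seq_index = (num_tokens // per_seq) // 2
    split_token_index = split_seq_index * per_seq
    return split_seq_index, split_token_index


# Example: a target-verify batch of 8 sequences with 4 draft tokens each.
per_seq = token_num_per_seq("target_verify", draft_token_num=4)
print(split_indices(num_tokens=8 * per_seq, per_seq=per_seq))  # -> (4, 16)
```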