sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import dataclasses
|
2
2
|
import logging
|
3
|
-
from
|
3
|
+
from dataclasses import replace
|
4
|
+
from typing import Dict, List, Optional, Sequence, Union
|
4
5
|
|
5
6
|
import torch
|
6
7
|
|
@@ -12,10 +13,11 @@ from sglang.srt.layers.communicator import (
|
|
12
13
|
)
|
13
14
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
14
15
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
15
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
16
|
+
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
16
17
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
17
18
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
18
19
|
from sglang.srt.operations_strategy import OperationsStrategy
|
20
|
+
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
19
21
|
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
|
20
22
|
|
21
23
|
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
@@ -26,17 +28,34 @@ logger = logging.getLogger(__name__)
|
|
26
28
|
# -------------------------------- Compute Basic Info ---------------------------------------
|
27
29
|
|
28
30
|
|
31
|
+
def get_token_num_per_seq(
|
32
|
+
forward_mode: ForwardMode,
|
33
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
34
|
+
):
|
35
|
+
if forward_mode.is_target_verify():
|
36
|
+
return spec_info.draft_token_num
|
37
|
+
elif forward_mode.is_decode():
|
38
|
+
return 1
|
39
|
+
elif forward_mode.is_idle():
|
40
|
+
return 0
|
41
|
+
else:
|
42
|
+
# For extend, we should not use `token_num_per_seq`.
|
43
|
+
return None
|
44
|
+
|
45
|
+
|
29
46
|
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
|
30
47
|
def compute_split_seq_index(
|
31
48
|
forward_mode: "ForwardMode",
|
32
49
|
num_tokens: int,
|
33
50
|
extend_lens: Optional[Sequence[int]],
|
51
|
+
token_num_per_seq: Optional[int],
|
34
52
|
) -> Optional[int]:
|
35
|
-
if forward_mode.
|
53
|
+
if forward_mode == ForwardMode.EXTEND:
|
36
54
|
assert extend_lens is not None
|
37
55
|
return _split_array_by_half_sum(extend_lens)
|
38
|
-
elif forward_mode.is_decode():
|
39
|
-
|
56
|
+
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
57
|
+
assert token_num_per_seq is not None
|
58
|
+
return (num_tokens // token_num_per_seq) // 2
|
40
59
|
elif forward_mode.is_idle():
|
41
60
|
assert num_tokens == 0
|
42
61
|
return 0
|
@@ -63,16 +82,103 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
|
63
82
|
return best_index
|
64
83
|
|
65
84
|
|
85
|
+
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
86
|
+
if seq_index == 0:
|
87
|
+
return 0
|
88
|
+
|
89
|
+
offset = 0
|
90
|
+
max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
|
91
|
+
for i in range(max_seq_len):
|
92
|
+
offset += (
|
93
|
+
spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
|
94
|
+
) * spec_info.draft_token_num
|
95
|
+
return offset
|
96
|
+
|
97
|
+
|
98
|
+
def split_spec_info(
|
99
|
+
spec_info: Optional[EagleVerifyInput],
|
100
|
+
start_seq_index: int,
|
101
|
+
end_seq_index: int,
|
102
|
+
start_token_index: int,
|
103
|
+
end_token_index: int,
|
104
|
+
):
|
105
|
+
if spec_info is None:
|
106
|
+
return None
|
107
|
+
if spec_info.draft_token is not None:
|
108
|
+
draft_token = spec_info.draft_token[start_token_index:end_token_index]
|
109
|
+
else:
|
110
|
+
draft_token = None
|
111
|
+
if spec_info.custom_mask is not None and spec_info.draft_token is not None:
|
112
|
+
custom_mask_start = _compute_mask_offset(start_seq_index, spec_info)
|
113
|
+
if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
|
114
|
+
custom_mask_end = spec_info.custom_mask.shape[0]
|
115
|
+
else:
|
116
|
+
custom_mask_end = _compute_mask_offset(end_seq_index, spec_info)
|
117
|
+
|
118
|
+
if custom_mask_end > custom_mask_start:
|
119
|
+
custom_mask = spec_info.custom_mask[custom_mask_start:custom_mask_end]
|
120
|
+
else:
|
121
|
+
custom_mask = spec_info.custom_mask
|
122
|
+
else:
|
123
|
+
custom_mask = spec_info.custom_mask
|
124
|
+
if spec_info.positions is not None:
|
125
|
+
positions = spec_info.positions[start_token_index:end_token_index]
|
126
|
+
else:
|
127
|
+
positions = None
|
128
|
+
if spec_info.retrive_index is not None:
|
129
|
+
retrive_index = spec_info.retrive_index[start_seq_index:end_seq_index]
|
130
|
+
else:
|
131
|
+
retrive_index = None
|
132
|
+
if spec_info.retrive_next_token is not None:
|
133
|
+
retrive_next_token = spec_info.retrive_next_token[start_seq_index:end_seq_index]
|
134
|
+
else:
|
135
|
+
retrive_next_token = None
|
136
|
+
if spec_info.retrive_next_sibling is not None:
|
137
|
+
retrive_next_sibling = spec_info.retrive_next_sibling[
|
138
|
+
start_seq_index:end_seq_index
|
139
|
+
]
|
140
|
+
else:
|
141
|
+
retrive_next_sibling = None
|
142
|
+
if spec_info.retrive_cum_len is not None:
|
143
|
+
retrive_cum_len = spec_info.retrive_cum_len[start_seq_index:end_seq_index]
|
144
|
+
else:
|
145
|
+
retrive_cum_len = None
|
146
|
+
|
147
|
+
if spec_info.seq_lens_cpu is not None:
|
148
|
+
seq_lens_cpu = spec_info.seq_lens_cpu[start_seq_index:end_seq_index]
|
149
|
+
else:
|
150
|
+
seq_lens_cpu = None
|
151
|
+
if seq_lens_cpu is not None:
|
152
|
+
seq_lens_sum = seq_lens_cpu.sum()
|
153
|
+
else:
|
154
|
+
seq_lens_sum = None
|
155
|
+
output_spec_info = replace(
|
156
|
+
spec_info,
|
157
|
+
custom_mask=custom_mask,
|
158
|
+
draft_token=draft_token,
|
159
|
+
positions=positions,
|
160
|
+
retrive_index=retrive_index,
|
161
|
+
retrive_next_token=retrive_next_token,
|
162
|
+
retrive_next_sibling=retrive_next_sibling,
|
163
|
+
retrive_cum_len=retrive_cum_len,
|
164
|
+
seq_lens_cpu=seq_lens_cpu,
|
165
|
+
seq_lens_sum=seq_lens_sum,
|
166
|
+
)
|
167
|
+
return output_spec_info
|
168
|
+
|
169
|
+
|
66
170
|
def compute_split_token_index(
|
67
171
|
split_seq_index: int,
|
68
172
|
forward_mode: "ForwardMode",
|
69
173
|
extend_seq_lens: Optional[Sequence[int]],
|
174
|
+
token_num_per_seq: Optional[int],
|
70
175
|
) -> int:
|
71
|
-
if forward_mode.
|
176
|
+
if forward_mode == ForwardMode.EXTEND:
|
72
177
|
assert extend_seq_lens is not None
|
73
178
|
return sum(extend_seq_lens[:split_seq_index])
|
74
|
-
elif forward_mode.is_decode():
|
75
|
-
|
179
|
+
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
180
|
+
assert token_num_per_seq is not None
|
181
|
+
return split_seq_index * token_num_per_seq
|
76
182
|
elif forward_mode.is_idle():
|
77
183
|
assert split_seq_index == 0
|
78
184
|
return 0
|
@@ -83,19 +189,25 @@ def compute_split_token_index(
|
|
83
189
|
def compute_split_indices_for_cuda_graph_replay(
|
84
190
|
forward_mode: ForwardMode,
|
85
191
|
cuda_graph_num_tokens: int,
|
192
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
86
193
|
):
|
87
194
|
forward_mode_for_tbo_split = (
|
88
195
|
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
|
89
196
|
)
|
197
|
+
token_num_per_seq = get_token_num_per_seq(
|
198
|
+
forward_mode=forward_mode, spec_info=spec_info
|
199
|
+
)
|
90
200
|
tbo_split_seq_index = compute_split_seq_index(
|
91
201
|
forward_mode=forward_mode_for_tbo_split,
|
92
202
|
num_tokens=cuda_graph_num_tokens,
|
93
203
|
extend_lens=None,
|
204
|
+
token_num_per_seq=token_num_per_seq,
|
94
205
|
)
|
95
206
|
tbo_split_token_index = compute_split_token_index(
|
96
207
|
split_seq_index=tbo_split_seq_index,
|
97
208
|
forward_mode=forward_mode_for_tbo_split,
|
98
209
|
extend_seq_lens=None,
|
210
|
+
token_num_per_seq=token_num_per_seq,
|
99
211
|
)
|
100
212
|
return tbo_split_seq_index, tbo_split_token_index
|
101
213
|
|
@@ -110,11 +222,15 @@ class TboCudaGraphRunnerPlugin:
|
|
110
222
|
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
111
223
|
if not global_server_args_dict["enable_two_batch_overlap"]:
|
112
224
|
return
|
225
|
+
token_num_per_seq = get_token_num_per_seq(
|
226
|
+
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
227
|
+
)
|
113
228
|
|
114
229
|
batch.tbo_split_seq_index = compute_split_seq_index(
|
115
230
|
forward_mode=batch.forward_mode,
|
116
231
|
num_tokens=num_tokens,
|
117
232
|
extend_lens=None,
|
233
|
+
token_num_per_seq=token_num_per_seq,
|
118
234
|
)
|
119
235
|
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
|
120
236
|
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
|
@@ -129,13 +245,20 @@ class TboCudaGraphRunnerPlugin:
|
|
129
245
|
)
|
130
246
|
|
131
247
|
def replay_prepare(
|
132
|
-
self,
|
248
|
+
self,
|
249
|
+
forward_mode: ForwardMode,
|
250
|
+
bs: int,
|
251
|
+
num_token_non_padded: int,
|
252
|
+
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
133
253
|
):
|
254
|
+
token_num_per_seq = get_token_num_per_seq(
|
255
|
+
forward_mode=forward_mode, spec_info=spec_info
|
256
|
+
)
|
134
257
|
tbo_split_seq_index, tbo_split_token_index = (
|
135
258
|
compute_split_indices_for_cuda_graph_replay(
|
136
259
|
forward_mode=forward_mode,
|
137
|
-
|
138
|
-
|
260
|
+
cuda_graph_num_tokens=bs * token_num_per_seq,
|
261
|
+
spec_info=spec_info,
|
139
262
|
)
|
140
263
|
)
|
141
264
|
|
@@ -149,19 +272,38 @@ class TboCudaGraphRunnerPlugin:
|
|
149
272
|
|
150
273
|
class TboDPAttentionPreparer:
|
151
274
|
def prepare_all_gather(
|
152
|
-
self,
|
275
|
+
self,
|
276
|
+
local_batch: ScheduleBatch,
|
277
|
+
deepep_mode: DeepEPMode,
|
278
|
+
enable_deepep_moe: bool,
|
279
|
+
enable_two_batch_overlap: bool,
|
153
280
|
):
|
154
281
|
self.enable_two_batch_overlap = enable_two_batch_overlap
|
155
282
|
|
156
283
|
if local_batch is not None:
|
284
|
+
token_num_per_seq = get_token_num_per_seq(
|
285
|
+
forward_mode=local_batch.forward_mode, spec_info=local_batch.spec_info
|
286
|
+
)
|
287
|
+
|
288
|
+
if (
|
289
|
+
local_batch.forward_mode.is_target_verify()
|
290
|
+
or local_batch.forward_mode.is_decode()
|
291
|
+
):
|
292
|
+
num_tokens = local_batch.batch_size() * token_num_per_seq
|
293
|
+
else:
|
294
|
+
num_tokens = local_batch.extend_num_tokens
|
157
295
|
self.local_tbo_split_seq_index = compute_split_seq_index(
|
158
296
|
forward_mode=local_batch.forward_mode,
|
159
|
-
num_tokens=
|
297
|
+
num_tokens=num_tokens,
|
160
298
|
extend_lens=local_batch.extend_lens,
|
299
|
+
token_num_per_seq=token_num_per_seq,
|
161
300
|
)
|
162
|
-
resolved_deepep_mode = deepep_mode.resolve(local_batch.
|
301
|
+
resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
|
163
302
|
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
164
|
-
|
303
|
+
(
|
304
|
+
local_batch.forward_mode.is_extend()
|
305
|
+
and not local_batch.forward_mode.is_target_verify()
|
306
|
+
)
|
165
307
|
and enable_deepep_moe
|
166
308
|
and (resolved_deepep_mode == DeepEPMode.low_latency)
|
167
309
|
)
|
@@ -218,8 +360,8 @@ class TboDPAttentionPreparer:
|
|
218
360
|
|
219
361
|
class TboForwardBatchPreparer:
|
220
362
|
@classmethod
|
221
|
-
def prepare(cls, batch: ForwardBatch):
|
222
|
-
if batch.tbo_split_seq_index is None:
|
363
|
+
def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
|
364
|
+
if batch.tbo_split_seq_index is None or is_draft_worker:
|
223
365
|
return
|
224
366
|
|
225
367
|
tbo_children_num_token_non_padded = (
|
@@ -242,7 +384,9 @@ class TboForwardBatchPreparer:
|
|
242
384
|
f"TboForwardBatchPreparer.prepare "
|
243
385
|
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
244
386
|
f"tbo_split_token_index={tbo_split_token_index} "
|
245
|
-
f"extend_seq_lens={batch.extend_seq_lens_cpu}"
|
387
|
+
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
388
|
+
f"bs={batch.batch_size} "
|
389
|
+
f"forward_mode={batch.forward_mode}"
|
246
390
|
)
|
247
391
|
|
248
392
|
assert isinstance(batch.attn_backend, TboAttnBackend)
|
@@ -286,6 +430,9 @@ class TboForwardBatchPreparer:
|
|
286
430
|
output_attn_backend: AttentionBackend,
|
287
431
|
out_num_token_non_padded: torch.Tensor,
|
288
432
|
):
|
433
|
+
assert (
|
434
|
+
end_token_index >= start_token_index
|
435
|
+
), f"{end_token_index=}, {start_token_index=}, batch={batch}"
|
289
436
|
num_tokens = batch.input_ids.shape[0]
|
290
437
|
num_seqs = batch.batch_size
|
291
438
|
|
@@ -317,11 +464,30 @@ class TboForwardBatchPreparer:
|
|
317
464
|
old_value = getattr(batch, key)
|
318
465
|
if old_value is None:
|
319
466
|
continue
|
467
|
+
elif batch.forward_mode.is_target_verify() and (
|
468
|
+
key == "extend_seq_lens"
|
469
|
+
or key == "extend_prefix_lens"
|
470
|
+
or key == "extend_start_loc"
|
471
|
+
or key == "extend_prefix_lens_cpu"
|
472
|
+
or key == "extend_seq_lens_cpu"
|
473
|
+
or key == "extend_logprob_start_lens_cpu"
|
474
|
+
):
|
475
|
+
output_dict[key] = None
|
476
|
+
continue
|
320
477
|
assert (
|
321
478
|
len(old_value) == num_seqs
|
322
479
|
), f"{key=} {old_value=} {num_seqs=} {batch=}"
|
323
480
|
output_dict[key] = old_value[start_seq_index:end_seq_index]
|
324
481
|
|
482
|
+
spec_info = getattr(batch, "spec_info")
|
483
|
+
output_spec_info = split_spec_info(
|
484
|
+
spec_info=spec_info,
|
485
|
+
start_token_index=start_token_index,
|
486
|
+
end_token_index=end_token_index,
|
487
|
+
start_seq_index=start_seq_index,
|
488
|
+
end_seq_index=end_seq_index,
|
489
|
+
)
|
490
|
+
output_dict["spec_info"] = output_spec_info
|
325
491
|
for key in [
|
326
492
|
"forward_mode",
|
327
493
|
"return_logprob",
|
@@ -329,24 +495,26 @@ class TboForwardBatchPreparer:
|
|
329
495
|
"token_to_kv_pool",
|
330
496
|
"can_run_dp_cuda_graph",
|
331
497
|
"global_forward_mode",
|
332
|
-
"spec_info",
|
333
498
|
"spec_algorithm",
|
334
499
|
"capture_hidden_mode",
|
335
500
|
"padded_static_len",
|
336
501
|
"mrope_positions", # only used by qwen2-vl, thus not care
|
337
502
|
]:
|
338
503
|
output_dict[key] = getattr(batch, key)
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
504
|
+
if not batch.forward_mode.is_target_verify():
|
505
|
+
assert (
|
506
|
+
_compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
|
507
|
+
== batch.extend_num_tokens
|
508
|
+
), f"{batch=}"
|
344
509
|
extend_num_tokens = _compute_extend_num_tokens(
|
345
510
|
output_dict["input_ids"], output_dict["forward_mode"]
|
346
511
|
)
|
347
512
|
|
348
513
|
# TODO improve, e.g. unify w/ `init_raw`
|
349
|
-
if
|
514
|
+
if (
|
515
|
+
global_server_args_dict["moe_dense_tp_size"] == 1
|
516
|
+
and batch.gathered_buffer is not None
|
517
|
+
):
|
350
518
|
sum_len = end_token_index - start_token_index
|
351
519
|
gathered_buffer = torch.zeros(
|
352
520
|
(sum_len, batch.gathered_buffer.shape[1]),
|
@@ -416,18 +584,26 @@ class TboForwardBatchPreparer:
|
|
416
584
|
|
417
585
|
@classmethod
|
418
586
|
def _compute_split_token_index(cls, batch: ForwardBatch):
|
587
|
+
token_num_per_seq = get_token_num_per_seq(
|
588
|
+
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
589
|
+
)
|
419
590
|
return compute_split_token_index(
|
420
591
|
split_seq_index=batch.tbo_split_seq_index,
|
421
592
|
forward_mode=batch.forward_mode,
|
422
593
|
extend_seq_lens=batch.extend_seq_lens_cpu,
|
594
|
+
token_num_per_seq=token_num_per_seq,
|
423
595
|
)
|
424
596
|
|
425
597
|
|
426
598
|
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
|
427
|
-
if
|
428
|
-
|
429
|
-
|
599
|
+
if (
|
600
|
+
forward_mode.is_decode()
|
601
|
+
or forward_mode.is_idle()
|
602
|
+
or forward_mode.is_target_verify()
|
603
|
+
):
|
430
604
|
return None
|
605
|
+
elif forward_mode.is_extend():
|
606
|
+
return input_ids.shape[0]
|
431
607
|
raise NotImplementedError
|
432
608
|
|
433
609
|
|