sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
20
20
|
ForwardMode,
|
21
21
|
)
|
22
22
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
23
|
+
from sglang.srt.utils import (
|
24
|
+
require_attn_tp_gather,
|
25
|
+
require_gathered_buffer,
|
26
|
+
require_mlp_sync,
|
27
|
+
require_mlp_tp_gather,
|
28
|
+
)
|
23
29
|
|
24
30
|
if TYPE_CHECKING:
|
25
31
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
@@ -38,9 +44,18 @@ class EAGLEDraftCudaGraphRunner:
|
|
38
44
|
self.output_buffers = {}
|
39
45
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
40
46
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
47
|
+
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
48
|
+
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
49
|
+
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
50
|
+
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
51
|
+
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
52
|
+
self.dp_size = self.model_runner.dp_size
|
41
53
|
self.tp_size = self.model_runner.tp_size
|
42
54
|
self.topk = model_runner.server_args.speculative_eagle_topk
|
43
55
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
56
|
+
self.enable_profile_cuda_graph = (
|
57
|
+
model_runner.server_args.enable_profile_cuda_graph
|
58
|
+
)
|
44
59
|
server_args = model_runner.server_args
|
45
60
|
|
46
61
|
# Batch sizes to capture
|
@@ -50,7 +65,9 @@ class EAGLEDraftCudaGraphRunner:
|
|
50
65
|
# Attention backend
|
51
66
|
self.max_bs = max(self.capture_bs)
|
52
67
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
53
|
-
self.model_runner.draft_attn_backend.init_cuda_graph_state(
|
68
|
+
self.model_runner.draft_attn_backend.init_cuda_graph_state(
|
69
|
+
self.max_bs, self.max_num_token
|
70
|
+
)
|
54
71
|
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
|
55
72
|
0
|
56
73
|
].get_cuda_graph_seq_len_fill_value()
|
@@ -75,10 +92,32 @@ class EAGLEDraftCudaGraphRunner:
|
|
75
92
|
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
76
93
|
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
77
94
|
self.hidden_states = torch.zeros(
|
78
|
-
(self.
|
95
|
+
(self.max_bs, self.model_runner.model_config.hidden_size),
|
79
96
|
dtype=self.model_runner.dtype,
|
80
97
|
)
|
81
98
|
|
99
|
+
if self.require_gathered_buffer:
|
100
|
+
self.gathered_buffer = torch.zeros(
|
101
|
+
(
|
102
|
+
self.max_num_token,
|
103
|
+
self.model_runner.model_config.hidden_size,
|
104
|
+
),
|
105
|
+
dtype=self.model_runner.dtype,
|
106
|
+
)
|
107
|
+
if self.require_mlp_tp_gather:
|
108
|
+
self.global_num_tokens_gpu = torch.zeros(
|
109
|
+
(self.dp_size,), dtype=torch.int32
|
110
|
+
)
|
111
|
+
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
112
|
+
(self.dp_size,), dtype=torch.int32
|
113
|
+
)
|
114
|
+
else:
|
115
|
+
assert self.require_attn_tp_gather
|
116
|
+
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
117
|
+
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
118
|
+
(1,), dtype=torch.int32
|
119
|
+
)
|
120
|
+
|
82
121
|
# Capture
|
83
122
|
try:
|
84
123
|
with model_capture_mode():
|
@@ -89,11 +128,24 @@ class EAGLEDraftCudaGraphRunner:
|
|
89
128
|
)
|
90
129
|
|
91
130
|
def can_run(self, forward_batch: ForwardBatch):
|
131
|
+
if self.require_mlp_tp_gather:
|
132
|
+
cuda_graph_bs = (
|
133
|
+
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
134
|
+
if self.model_runner.spec_algorithm.is_eagle()
|
135
|
+
else sum(forward_batch.global_num_tokens_cpu)
|
136
|
+
)
|
137
|
+
else:
|
138
|
+
cuda_graph_bs = forward_batch.batch_size
|
139
|
+
|
92
140
|
is_bs_supported = (
|
93
|
-
|
141
|
+
cuda_graph_bs in self.graphs
|
94
142
|
if self.disable_padding
|
95
|
-
else
|
143
|
+
else cuda_graph_bs <= self.max_bs
|
96
144
|
)
|
145
|
+
|
146
|
+
if self.require_mlp_sync:
|
147
|
+
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
148
|
+
|
97
149
|
return is_bs_supported
|
98
150
|
|
99
151
|
def capture(self):
|
@@ -113,10 +165,58 @@ class EAGLEDraftCudaGraphRunner:
|
|
113
165
|
topk_index = self.topk_index[:num_seqs]
|
114
166
|
hidden_states = self.hidden_states[:num_seqs]
|
115
167
|
|
168
|
+
if self.require_mlp_tp_gather:
|
169
|
+
self.global_num_tokens_gpu.copy_(
|
170
|
+
torch.tensor(
|
171
|
+
[
|
172
|
+
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
173
|
+
for i in range(self.dp_size)
|
174
|
+
],
|
175
|
+
dtype=torch.int32,
|
176
|
+
device=self.input_ids.device,
|
177
|
+
)
|
178
|
+
)
|
179
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
180
|
+
torch.tensor(
|
181
|
+
[
|
182
|
+
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
183
|
+
for i in range(self.dp_size)
|
184
|
+
],
|
185
|
+
dtype=torch.int32,
|
186
|
+
device=self.input_ids.device,
|
187
|
+
)
|
188
|
+
)
|
189
|
+
global_num_tokens = self.global_num_tokens_gpu
|
190
|
+
gathered_buffer = self.gathered_buffer[:num_tokens]
|
191
|
+
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
192
|
+
elif self.require_attn_tp_gather:
|
193
|
+
self.global_num_tokens_gpu.copy_(
|
194
|
+
torch.tensor(
|
195
|
+
[num_tokens],
|
196
|
+
dtype=torch.int32,
|
197
|
+
device=self.input_ids.device,
|
198
|
+
)
|
199
|
+
)
|
200
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
201
|
+
torch.tensor(
|
202
|
+
[num_tokens],
|
203
|
+
dtype=torch.int32,
|
204
|
+
device=self.input_ids.device,
|
205
|
+
)
|
206
|
+
)
|
207
|
+
global_num_tokens = self.global_num_tokens_gpu
|
208
|
+
gathered_buffer = self.gathered_buffer[:num_tokens]
|
209
|
+
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
210
|
+
else:
|
211
|
+
global_num_tokens = None
|
212
|
+
gathered_buffer = None
|
213
|
+
global_num_tokens_for_logprob = None
|
214
|
+
|
116
215
|
spec_info = EagleDraftInput(
|
117
216
|
topk_p=topk_p,
|
118
217
|
topk_index=topk_index,
|
119
218
|
hidden_states=hidden_states,
|
219
|
+
capture_hidden_mode=CaptureHiddenMode.LAST,
|
120
220
|
)
|
121
221
|
|
122
222
|
# Forward batch
|
@@ -132,11 +232,14 @@ class EAGLEDraftCudaGraphRunner:
|
|
132
232
|
seq_lens_sum=seq_lens.sum().item(),
|
133
233
|
return_logprob=False,
|
134
234
|
positions=positions,
|
235
|
+
global_num_tokens_gpu=global_num_tokens,
|
236
|
+
gathered_buffer=gathered_buffer,
|
135
237
|
spec_algorithm=self.model_runner.spec_algorithm,
|
136
238
|
spec_info=spec_info,
|
137
239
|
capture_hidden_mode=(
|
138
240
|
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
139
241
|
),
|
242
|
+
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
140
243
|
)
|
141
244
|
|
142
245
|
# Attention backend
|
@@ -146,6 +249,9 @@ class EAGLEDraftCudaGraphRunner:
|
|
146
249
|
|
147
250
|
# Run and capture
|
148
251
|
def run_once():
|
252
|
+
# Clean intermediate result cache for DP attention
|
253
|
+
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
254
|
+
|
149
255
|
# Backup two fields, which will be modified in-place in `draft_forward`.
|
150
256
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
151
257
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
@@ -183,12 +289,19 @@ class EAGLEDraftCudaGraphRunner:
|
|
183
289
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
184
290
|
|
185
291
|
# Pad
|
186
|
-
|
292
|
+
if self.require_mlp_tp_gather:
|
293
|
+
total_batch_size = (
|
294
|
+
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
295
|
+
if self.model_runner.spec_algorithm.is_eagle()
|
296
|
+
else sum(forward_batch.global_num_tokens_cpu)
|
297
|
+
)
|
298
|
+
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
299
|
+
else:
|
300
|
+
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
187
301
|
bs = self.capture_bs[index]
|
188
302
|
if bs != raw_bs:
|
189
|
-
self.seq_lens.fill_(
|
303
|
+
self.seq_lens.fill_(self.seq_len_fill_value)
|
190
304
|
self.out_cache_loc.zero_()
|
191
|
-
self.positions.zero_()
|
192
305
|
|
193
306
|
num_tokens = bs * self.num_tokens_per_bs
|
194
307
|
|
@@ -203,6 +316,13 @@ class EAGLEDraftCudaGraphRunner:
|
|
203
316
|
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
204
317
|
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
205
318
|
|
319
|
+
if self.require_gathered_buffer:
|
320
|
+
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
321
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
322
|
+
forward_batch.global_num_tokens_for_logprob_gpu
|
323
|
+
)
|
324
|
+
forward_batch.gathered_buffer = self.gathered_buffer
|
325
|
+
|
206
326
|
# Attention backend
|
207
327
|
if bs != raw_bs:
|
208
328
|
forward_batch.batch_size = bs
|
@@ -211,14 +331,16 @@ class EAGLEDraftCudaGraphRunner:
|
|
211
331
|
forward_batch.positions = self.positions[:num_tokens]
|
212
332
|
|
213
333
|
# Special handle for seq_len_cpu used when flashinfer mla is used
|
214
|
-
if forward_batch.seq_lens_cpu is not None
|
215
|
-
|
334
|
+
if forward_batch.seq_lens_cpu is not None:
|
335
|
+
if bs != raw_bs:
|
336
|
+
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
216
337
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
217
338
|
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
218
339
|
|
219
340
|
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
220
341
|
forward_batch, bs
|
221
342
|
)
|
343
|
+
# TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
|
222
344
|
|
223
345
|
# Replay
|
224
346
|
self.graphs[bs].replay()
|
@@ -21,6 +21,12 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
21
21
|
ForwardMode,
|
22
22
|
)
|
23
23
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
24
|
+
from sglang.srt.utils import (
|
25
|
+
require_attn_tp_gather,
|
26
|
+
require_gathered_buffer,
|
27
|
+
require_mlp_sync,
|
28
|
+
require_mlp_tp_gather,
|
29
|
+
)
|
24
30
|
|
25
31
|
if TYPE_CHECKING:
|
26
32
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
@@ -35,10 +41,17 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
35
41
|
self.output_buffers = {}
|
36
42
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
37
43
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
44
|
+
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
45
|
+
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
46
|
+
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
|
47
|
+
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
38
48
|
self.tp_size = self.model_runner.tp_size
|
39
49
|
self.dp_size = model_runner.server_args.dp_size
|
40
50
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
41
51
|
self.topk = model_runner.server_args.speculative_eagle_topk
|
52
|
+
self.enable_profile_cuda_graph = (
|
53
|
+
model_runner.server_args.enable_profile_cuda_graph
|
54
|
+
)
|
42
55
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
43
56
|
self.padded_static_len = -1
|
44
57
|
|
@@ -48,7 +61,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
48
61
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
49
62
|
|
50
63
|
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
|
51
|
-
self.max_num_token
|
64
|
+
self.max_bs, self.max_num_token
|
52
65
|
)
|
53
66
|
self.seq_len_fill_value = (
|
54
67
|
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
|
@@ -83,10 +96,31 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
83
96
|
|
84
97
|
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
85
98
|
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
86
|
-
self.accept_length = (
|
87
|
-
|
99
|
+
self.accept_length = torch.full(
|
100
|
+
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
88
101
|
)
|
89
102
|
|
103
|
+
if self.require_gathered_buffer:
|
104
|
+
self.gathered_buffer = torch.zeros(
|
105
|
+
(
|
106
|
+
self.max_num_token,
|
107
|
+
self.model_runner.model_config.hidden_size,
|
108
|
+
),
|
109
|
+
dtype=self.model_runner.dtype,
|
110
|
+
)
|
111
|
+
if self.require_mlp_tp_gather:
|
112
|
+
self.global_num_tokens_gpu = torch.zeros(
|
113
|
+
(self.dp_size,), dtype=torch.int32
|
114
|
+
)
|
115
|
+
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
116
|
+
(self.dp_size,), dtype=torch.int32
|
117
|
+
)
|
118
|
+
else:
|
119
|
+
assert self.require_attn_tp_gather
|
120
|
+
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
121
|
+
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
122
|
+
(1,), dtype=torch.int32
|
123
|
+
)
|
90
124
|
# Capture
|
91
125
|
try:
|
92
126
|
with model_capture_mode():
|
@@ -97,14 +131,24 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
97
131
|
)
|
98
132
|
|
99
133
|
def can_run(self, forward_batch: ForwardBatch):
|
100
|
-
|
134
|
+
if self.require_mlp_tp_gather:
|
135
|
+
cuda_graph_bs = (
|
136
|
+
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
137
|
+
if self.model_runner.spec_algorithm.is_eagle()
|
138
|
+
else sum(forward_batch.global_num_tokens_cpu)
|
139
|
+
)
|
140
|
+
else:
|
141
|
+
cuda_graph_bs = forward_batch.seq_lens.numel()
|
101
142
|
|
102
143
|
is_bs_supported = (
|
103
|
-
|
144
|
+
cuda_graph_bs in self.graphs
|
104
145
|
if self.disable_padding
|
105
|
-
else
|
146
|
+
else cuda_graph_bs <= self.max_bs
|
106
147
|
)
|
107
148
|
|
149
|
+
if self.require_mlp_sync:
|
150
|
+
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
|
151
|
+
|
108
152
|
return is_bs_supported
|
109
153
|
|
110
154
|
def capture(self):
|
@@ -125,6 +169,53 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
125
169
|
positions = self.positions[:num_tokens]
|
126
170
|
hidden_states = self.hidden_states[:num_tokens]
|
127
171
|
|
172
|
+
if self.require_mlp_tp_gather:
|
173
|
+
self.global_num_tokens_gpu.copy_(
|
174
|
+
torch.tensor(
|
175
|
+
[
|
176
|
+
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
177
|
+
for i in range(self.dp_size)
|
178
|
+
],
|
179
|
+
dtype=torch.int32,
|
180
|
+
device=self.input_ids.device,
|
181
|
+
)
|
182
|
+
)
|
183
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
184
|
+
torch.tensor(
|
185
|
+
[
|
186
|
+
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
187
|
+
for i in range(self.dp_size)
|
188
|
+
],
|
189
|
+
dtype=torch.int32,
|
190
|
+
device=self.input_ids.device,
|
191
|
+
)
|
192
|
+
)
|
193
|
+
global_num_tokens = self.global_num_tokens_gpu
|
194
|
+
gathered_buffer = self.gathered_buffer[:num_tokens]
|
195
|
+
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
196
|
+
elif self.require_attn_tp_gather:
|
197
|
+
self.global_num_tokens_gpu.copy_(
|
198
|
+
torch.tensor(
|
199
|
+
[num_tokens],
|
200
|
+
dtype=torch.int32,
|
201
|
+
device=self.input_ids.device,
|
202
|
+
)
|
203
|
+
)
|
204
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
205
|
+
torch.tensor(
|
206
|
+
[num_tokens],
|
207
|
+
dtype=torch.int32,
|
208
|
+
device=self.input_ids.device,
|
209
|
+
)
|
210
|
+
)
|
211
|
+
global_num_tokens = self.global_num_tokens_gpu
|
212
|
+
gathered_buffer = self.gathered_buffer[:num_tokens]
|
213
|
+
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
214
|
+
else:
|
215
|
+
global_num_tokens = None
|
216
|
+
gathered_buffer = None
|
217
|
+
global_num_tokens_for_logprob = None
|
218
|
+
|
128
219
|
spec_info = EagleDraftInput(
|
129
220
|
hidden_states=hidden_states,
|
130
221
|
accept_length=accept_length,
|
@@ -144,6 +235,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
144
235
|
seq_lens_sum=seq_lens.sum().item(),
|
145
236
|
return_logprob=False,
|
146
237
|
positions=positions,
|
238
|
+
global_num_tokens_gpu=global_num_tokens,
|
239
|
+
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
240
|
+
gathered_buffer=gathered_buffer,
|
147
241
|
spec_algorithm=self.model_runner.spec_algorithm,
|
148
242
|
spec_info=spec_info,
|
149
243
|
capture_hidden_mode=CaptureHiddenMode.LAST,
|
@@ -164,6 +258,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
164
258
|
|
165
259
|
# Run and capture
|
166
260
|
def run_once():
|
261
|
+
# Clean intermediate result cache for DP attention
|
262
|
+
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
263
|
+
|
167
264
|
# Backup two fields, which will be modified in-place in `draft_forward`.
|
168
265
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
169
266
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
@@ -200,38 +297,57 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
200
297
|
# in the batch, which will not be counted as num_seqs
|
201
298
|
raw_bs = forward_batch.batch_size
|
202
299
|
num_tokens = forward_batch.input_ids.shape[0]
|
300
|
+
if self.require_mlp_tp_gather:
|
301
|
+
total_batch_size = (
|
302
|
+
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
303
|
+
if self.model_runner.spec_algorithm.is_eagle()
|
304
|
+
else sum(forward_batch.global_num_tokens_cpu)
|
305
|
+
)
|
306
|
+
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
307
|
+
else:
|
308
|
+
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
203
309
|
|
204
|
-
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
205
310
|
bs = self.capture_bs[index]
|
206
311
|
if bs * self.num_tokens_per_bs != num_tokens:
|
207
|
-
self.seq_lens.fill_(
|
208
|
-
self.accept_length.fill_(1)
|
312
|
+
self.seq_lens.fill_(self.seq_len_fill_value)
|
209
313
|
self.out_cache_loc.zero_()
|
314
|
+
self.accept_length.fill_(1)
|
315
|
+
self.extend_seq_lens.fill_(1)
|
210
316
|
|
211
317
|
# Common inputs
|
212
318
|
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
213
319
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
214
|
-
|
320
|
+
if forward_batch.extend_seq_lens is not None:
|
321
|
+
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
215
322
|
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
216
323
|
self.positions[:num_tokens].copy_(forward_batch.positions)
|
217
324
|
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
218
|
-
|
325
|
+
if forward_batch.spec_info.accept_length is not None:
|
326
|
+
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
219
327
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
220
328
|
|
329
|
+
if self.require_gathered_buffer:
|
330
|
+
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
331
|
+
self.global_num_tokens_for_logprob_gpu.copy_(
|
332
|
+
forward_batch.global_num_tokens_for_logprob_gpu
|
333
|
+
)
|
334
|
+
forward_batch.gathered_buffer = self.gathered_buffer
|
335
|
+
|
221
336
|
if forward_batch.seq_lens_cpu is not None:
|
222
337
|
if bs != raw_bs:
|
223
|
-
self.seq_lens_cpu.fill_(
|
338
|
+
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
224
339
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
225
340
|
|
226
341
|
if bs != raw_bs:
|
342
|
+
forward_batch.spec_info.positions = self.positions[:num_tokens]
|
227
343
|
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|
228
|
-
forward_batch.spec_info.positions = None
|
229
344
|
|
230
345
|
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
|
231
346
|
bs=bs,
|
232
347
|
req_pool_indices=self.req_pool_indices,
|
233
348
|
seq_lens=self.seq_lens,
|
234
|
-
seq_lens_sum=forward_batch.seq_lens_sum
|
349
|
+
seq_lens_sum=forward_batch.seq_lens_sum
|
350
|
+
+ (bs - raw_bs) * self.seq_len_fill_value,
|
235
351
|
encoder_lens=None,
|
236
352
|
forward_mode=ForwardMode.DRAFT_EXTEND,
|
237
353
|
spec_info=forward_batch.spec_info,
|