sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py +258 -84

@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import
-
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -35,11 +39,17 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    is_cuda,
+    next_power_of_2,
+)
 
 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -51,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
@@ -70,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -152,8 +163,18 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()
 
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
+
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -201,7 +222,6 @@ class EAGLEWorker(TpModelWorker):
                     self.draft_model_runner,
                     skip_prefill=False,
                 )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -217,7 +237,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -228,8 +247,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -254,7 +271,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s.
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
         # Capture extend
@@ -269,7 +286,7 @@ class EAGLEWorker(TpModelWorker):
             )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s.
+            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     @property
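The two log-message changes above compute the capture cost as the drop in available GPU memory across graph capture. A minimal sketch of the same measurement, using plain `torch.cuda.mem_get_info` instead of sglang's `get_available_gpu_memory` helper:

```python
# Sketch of the measurement behind the new log lines; `free_gib` and the
# placeholder capture step are illustrative, not sglang's own helpers.
import time
import torch

def free_gib(device: int = 0) -> float:
    free_bytes, _total = torch.cuda.mem_get_info(device)
    return free_bytes / (1 << 30)

tic = time.perf_counter()
before_mem = free_gib()
# ... capture CUDA graphs here ...
after_mem = free_gib()
print(
    f"Capture end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
    f"mem usage={before_mem - after_mem:.2f} GB. avail mem={after_mem:.2f} GB."
)
```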
@@ -290,17 +307,27 @@ class EAGLEWorker(TpModelWorker):
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
 
-
-            if batch.spec_info.verified_id is not None:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(
+                    self.forward_draft_extend_after_decode(
+                        batch,
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -308,22 +335,27 @@ class EAGLEWorker(TpModelWorker):
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
 
-
-
-
-
-
-
-
-
-
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
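A note on the new `check_forward_draft_extend_after_decode`: with DP attention enabled, one rank may have no verified tokens while another does, so all ranks must agree on whether to run the extra draft-extend forward, or the collective ops inside it would desynchronize. The hunk above reaches that consensus with an all-reduce over the TP CPU group. A minimal standalone sketch of the same pattern, assuming `torch.distributed` has been initialized with a gloo (CPU) group:

```python
# Consensus sketch: every rank contributes a 0/1 flag and runs the extra
# forward iff any rank needs it, keeping collectives aligned across ranks.
import torch
import torch.distributed as dist

def any_rank_needs_forward(local_need: bool, cpu_group) -> bool:
    flag = torch.tensor([int(local_need)], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=cpu_group)  # gloo CPU group
    return flag.item() > 0
```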
@@ -342,6 +374,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -352,7 +385,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch.seq_lens_cpu,
         )
 
-    def
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -365,14 +398,21 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Allocate cache locations
+        # Layout of the out_cache_loc
+        # [       topk 0        ] [       topk 1        ]
+        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.
+                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens =
-
+                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                )
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -385,29 +425,33 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens
 
-                # TODO:
-
-
-
-
-
-
-
-
-
-
-
+                # TODO(lmzheng): The current implementation is still a fake support
+                # for page size > 1. In the `assign_draft_cache_locs` below,
+                # we directly move the indices instead of the real kv cache.
+                # This only works when the kernel backend runs with page size = 1.
+                # If the kernel backend runs with page size > 1, we need to
+                # duplicate the real KV cache. The overhead of duplicating KV
+                # cache seems okay because the draft KV cache only has one layer.
+                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+
+                (
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.num_new_pages_per_topk,
+                    self.extend_lens,
+                ) = get_last_loc_large_page_size_large_top_k(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                    self.topk,
+                    self.page_size,
                 )
-
-                #
-
-
-                last_loc = get_last_loc(
-                    batch.req_to_token_pool.req_to_token,
-                    batch.req_pool_indices,
-                    prefix_lens,
-                )
+
+                # TODO(lmzheng): remove this device sync
+                extend_num_tokens = torch.sum(self.extend_lens).item()
+
                 out_cache_loc, token_to_kv_pool_state_backup = (
                     batch.alloc_paged_token_slots_extend(
                         prefix_lens,
@@ -422,19 +466,54 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
+            self.extend_lens,
+            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
         )
+
+        if self.page_size > 1 and self.topk > 1:
+            # Remove padded slots
+            out_cache_loc = out_cache_loc[
+                : num_seqs * self.topk * self.speculative_num_steps
+            ]
+
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            dtype=self.model_config.dtype,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
 
-        # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
+
+        # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -446,15 +525,18 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
-
-
-
-                model_worker_batch, self.draft_model_runner
-            )
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -472,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -503,6 +585,13 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]
 
+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
@@ -524,10 +613,7 @@ class EAGLEWorker(TpModelWorker):
 
             # Set inputs
             forward_batch.input_ids = input_ids
-            out_cache_loc = out_cache_loc
-            forward_batch.out_cache_loc = out_cache_loc[
-                :, self.topk * i : self.topk * (i + 1)
-            ].flatten()
+            forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
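These two hunks together change how draft steps index the allocated cache slots: the allocation layout is `[topk 0][topk 1]...`, each holding `[iter=0, iter=1, ...]` (see the layout comment earlier), and the new reshape/permute turns that into one row per draft step, so step `i` can simply take `out_cache_loc[i]`. A standalone sketch with illustrative sizes:

```python
# Standalone sketch of the out_cache_loc reshuffle (sizes are illustrative).
import torch

bs, topk, steps = 2, 3, 4
# Allocation order per sequence: [topk 0: iter 0..steps-1][topk 1: ...] ...
out_cache_loc = torch.arange(bs * topk * steps)

x = out_cache_loc.reshape(bs, topk, steps)
x = x.permute((2, 0, 1)).reshape(steps, -1)

# Row i now holds the cache slots written by draft step i for all
# (sequence, topk-branch) pairs, so step i can use x[i] directly.
assert x.shape == (steps, bs * topk)
assert x[0].tolist() == [0, 4, 8, 12, 16, 20]
```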
@@ -547,11 +633,18 @@ class EAGLEWorker(TpModelWorker):
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
-        batch.
+        batch.return_hidden_states = False
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -583,7 +676,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # otherwise, this vocab mask will be the one from the previous extend stage
+            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
 
@@ -604,13 +697,15 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
-        # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
-        batch.spec_info = res.draft_input
-
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
+        # Prepare the batch for the next draft forwards.
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
+        batch.spec_info = res.draft_input
+
         return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
@@ -623,8 +718,16 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
+        accepted_indices = res.accepted_indices
+        assert len(accepted_indices) == len(logits_output.next_token_logits)
+        temperatures = batch.sampling_info.temperatures
+        num_draft_tokens = batch.spec_info.draft_token_num
+        # acceptance indices are the indices in a "flattened" batch.
+        # dividing it to num_draft_tokens will yield the actual batch index.
+        temperatures = temperatures[accepted_indices // num_draft_tokens]
+
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits, dim=-1
+            logits_output.next_token_logits / temperatures, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
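The logprob fix above accounts for per-request sampling temperatures: `accepted_indices` index a flattened `(batch, num_draft_tokens)` layout, so integer division by `num_draft_tokens` recovers each accepted token's request index, and the matching temperature scales the logits before `log_softmax`. A small sketch with made-up sizes:

```python
# Map flattened draft-token indices back to request indices so each
# accepted token is scored with its own request's temperature.
import torch

num_reqs, num_draft_tokens, vocab = 2, 4, 8
logits = torch.randn(5, vocab)                    # 5 accepted tokens
accepted_indices = torch.tensor([0, 1, 4, 5, 6])  # flat indices into num_reqs * num_draft_tokens
temperatures = torch.tensor([[0.7], [1.3]])       # one per request, shape (num_reqs, 1)

req_idx = accepted_indices // num_draft_tokens    # -> [0, 0, 1, 1, 1]
t = temperatures[req_idx]                         # shape (5, 1), one per accepted token
logprobs = torch.log_softmax(logits / t, dim=-1)  # broadcasts over the vocab dim
```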
@@ -659,7 +762,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -691,11 +794,13 @@ class EAGLEWorker(TpModelWorker):
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -712,13 +817,33 @@ class EAGLEWorker(TpModelWorker):
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
-
-
-        batch
-
-
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -742,7 +867,10 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -752,7 +880,9 @@ class EAGLEWorker(TpModelWorker):
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode =
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
@@ -781,4 +911,48 @@ def load_token_map(token_map_path: str) -> List[int]:
     )
     token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_top_k_1(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens,
+    speculative_num_steps: int,
+):
+    prefix_lens = seq_lens
+    seq_lens = prefix_lens + speculative_num_steps
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_large_top_k(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    speculative_num_steps: int,
+    topk: int,
+    page_size: int,
+):
+    prefix_lens = seq_lens
+    last_page_lens = prefix_lens % page_size
+    num_new_pages_per_topk = (
+        last_page_lens + speculative_num_steps + page_size - 1
+    ) // page_size
+    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
+        page_size * topk
+    )
+    extend_lens = seq_lens - prefix_lens
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+
+    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
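A worked example of the page rounding in `get_last_loc_large_page_size_large_top_k` above, with illustrative numbers: each topk branch needs its own copy of the partial last page plus room for the draft tokens, so the padded sequence length grows by whole pages per branch.

```python
# Worked example of the page arithmetic (illustrative numbers:
# page_size=4, topk=2, 3 draft steps, a 10-token prefix).
page_size, topk, speculative_num_steps = 4, 2, 3
prefix_len = 10

last_page_len = prefix_len % page_size                  # 2 tokens already on the last page
num_new_pages_per_topk = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size                                          # ceil((2 + 3) / 4) = 2 pages per branch

# Drop the partial page, then add num_new_pages_per_topk pages for each
# of the topk branches (each branch re-duplicates the partial page).
seq_len = prefix_len // page_size * page_size + num_new_pages_per_topk * (
    page_size * topk
)                                                       # 8 + 2 * 8 = 24
extend_len = seq_len - prefix_len                       # 14 slots allocated (incl. padding)
assert (num_new_pages_per_topk, seq_len, extend_len) == (2, 24, 14)
```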
sglang/srt/torch_memory_saver_adapter.py +19 -15

@@ -1,11 +1,13 @@
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
 try:
     import torch_memory_saver
 
-
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError
 
-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError
 
-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError
 
     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self):
-        return
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)
 
-    def pause(self):
-        return
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)
 
-    def resume(self):
-        return
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)
 
     @property
     def enabled(self):
-        return
+        return _memory_saver is not None and _memory_saver.enabled
 
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield
 
-    def pause(self):
+    def pause(self, tag: str):
         pass
 
-    def resume(self):
+    def resume(self, tag: str):
         pass
 
     @property
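The adapter changes above give callers per-tag control over which pausable allocations are released and restored, instead of pausing everything at once. A hedged usage sketch, runnable against the no-op adapter shown in this diff; the tag name is illustrative, and with the real adapter the same calls would unmap and remap GPU memory for that tag only:

```python
# Usage sketch of the tag-based API; "kv_cache" is an illustrative tag,
# not a name mandated by sglang. The module path comes from the file list.
import torch
from sglang.srt.torch_memory_saver_adapter import _TorchMemorySaverAdapterNoop

adapter = _TorchMemorySaverAdapterNoop()

with adapter.region(tag="kv_cache"):
    kv_buffers = [torch.empty(1024) for _ in range(4)]  # pausable allocations

adapter.pause(tag="kv_cache")   # real adapter: release this tag's GPU memory
adapter.resume(tag="kv_cache")  # real adapter: map it back before serving resumes
```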