sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import
-
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
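Context for the hunk above: `patch_tensor_parallel_group` temporarily makes the draft model's dedicated TP group the active one while drafting. A minimal sketch of that save/swap/restore pattern, using hypothetical module-level state rather than sglang's actual group bookkeeping:

from contextlib import contextmanager

# Hypothetical stand-in for the process-global tensor-parallel group.
_CURRENT_TP_GROUP = None

@contextmanager
def patch_tp_group(group):
    # Swap the active group in, and always restore the previous one on
    # exit, even if the body raises.
    global _CURRENT_TP_GROUP
    saved = _CURRENT_TP_GROUP
    _CURRENT_TP_GROUP = group
    try:
        yield
    finally:
        _CURRENT_TP_GROUP = saved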
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -166,6 +171,10 @@ class EAGLEWorker(TpModelWorker):
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +222,6 @@
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +237,6 @@
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +247,6 @@
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -302,17 +307,27 @@
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
 
-
-        if batch.spec_info.verified_id is not None:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(
+                    self.forward_draft_extend_after_decode(
+                        batch,
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -320,22 +335,27 @@
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
 
-
-
-
-
-
-
-
-
-
-
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
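The new `check_forward_draft_extend_after_decode` must return the same answer on every rank when DP attention is enabled, otherwise ranks would diverge on whether to launch the extra draft-extend pass. The core of that consensus is a summed all-reduce of a local boolean over the TP CPU group; a self-contained sketch of the pattern (assuming `torch.distributed` is already initialized and `cpu_group` is a CPU-backed process group):

import torch
import torch.distributed as dist

def any_rank_needs_forward(local_need: bool, cpu_group) -> bool:
    # Encode the local decision as 0/1 and SUM it across ranks
    # (all_reduce defaults to ReduceOp.SUM).
    flag = torch.tensor([int(local_need)], dtype=torch.int64)
    dist.all_reduce(flag, group=cpu_group)
    # Any rank voting "yes" forces every rank to run the forward.
    return flag.item() > 0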
@@ -354,6 +374,7 @@
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -364,7 +385,7 @@
             model_worker_batch.seq_lens_cpu,
         )
 
-    def
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -466,10 +487,33 @@
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            dtype=self.model_config.dtype,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
+
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -481,12 +525,18 @@
                 forward_batch
             )
         else:
-
-
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -504,7 +554,7 @@
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -584,11 +634,16 @@
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
-        batch.forward_mode =
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -646,7 +701,9 @@
             self.add_logprob_values(batch, res, logits_output)
 
         # Prepare the batch for the next draft forwards.
-        batch.forward_mode =
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
         batch.spec_info = res.draft_input
 
         return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +800,7 @@
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -759,13 +817,33 @@
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
-
-
-        batch
-
-
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -789,7 +867,10 @@
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -799,7 +880,9 @@
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode =
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
sglang/srt/torch_memory_saver_adapter.py
CHANGED
@@ -1,11 +1,13 @@
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
 try:
     import torch_memory_saver
 
-
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError
 
-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError
 
-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError
 
     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self):
-        return
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)
 
-    def pause(self):
-        return
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)
 
-    def resume(self):
-        return
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)
 
     @property
     def enabled(self):
-        return
+        return _memory_saver is not None and _memory_saver.enabled
 
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield
 
-    def pause(self):
+    def pause(self, tag: str):
         pass
 
-    def resume(self):
+    def resume(self, tag: str):
         pass
 
     @property
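The adapter now threads a `tag` through `region`/`pause`/`resume`, so different allocation pools (for example KV cache vs. weights) can be paused and resumed independently. A toy model of the tag-based bookkeeping idea, assuming only the general concept and not the torch_memory_saver API:

from contextlib import contextmanager

class TaggedMemoryRegions:
    def __init__(self):
        self._paused = set()  # tags whose memory is currently released

    @contextmanager
    def region(self, tag: str):
        # Allocations made inside this block are attributed to `tag`.
        yield

    def pause(self, tag: str):
        # Release (e.g. offload) only the memory registered under `tag`.
        self._paused.add(tag)

    def resume(self, tag: str):
        self._paused.discard(tag)

    def is_paused(self, tag: str) -> bool:
        return tag in self._paused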
sglang/srt/two_batch_overlap.py
CHANGED
@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
         )
 
         # TODO improve, e.g. unify w/ `init_raw`
-        if
+        if (
+            global_server_args_dict["moe_dense_tp_size"] == 1
+            and batch.gathered_buffer is not None
+        ):
             sum_len = end_token_index - start_token_index
             gathered_buffer = torch.zeros(
                 (sum_len, batch.gathered_buffer.shape[1]),