sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
Rendered diffs follow for three of the changed files: sglang/srt/speculative/eagle_worker.py, sglang/srt/speculative/spec_info.py, and sglang/srt/torch_memory_saver_adapter.py. Removed lines that the registry's viewer truncated are reproduced as truncated.

--- a/sglang/srt/speculative/eagle_worker.py
+++ b/sglang/srt/speculative/eagle_worker.py
@@ -1,11 +1,14 @@
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
 from huggingface_hub import snapshot_download
 
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,23 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits
 
 logger = logging.getLogger(__name__)
 
 
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):
 
     def __init__(
@@ -52,6 +67,9 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
 
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -67,7 +85,13 @@
         )
 
         # Load hot token ids
-        if
+        if self.speculative_algorithm.is_eagle3():
+            if server_args.speculative_token_map is not None:
+                logger.warning(
+                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
+                )
+            self.hot_token_id = None
+        elif server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
                 f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
@@ -76,30 +100,47 @@
             self.hot_token_id = None
 
         # Init draft worker
-
-
-
-
-
-
-
-
-
-
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
 
-        # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-
-
-
-
-
+
+        if self.speculative_algorithm.is_eagle3():
+            # EAGLE3 models don't share lm_head
+            self.draft_model_runner.model.set_embed(embed)
+
+            # grab hot token ids
+            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
+                embed.device
+            )
+        else:
+            if self.hot_token_id is not None:
+                head = head.clone()
+                self.hot_token_id = self.hot_token_id.to(head.device)
+                head.data = head.data[self.hot_token_id]
+
+            # Share the embedding and lm_head
+            self.draft_model_runner.model.set_embed_and_head(embed, head)
+
+        # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-
-
-
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
@@ -109,52 +150,70 @@
             )
 
             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
 
             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )
 
             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
        else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
+
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None
 
         if self.server_args.disable_cuda_graph:
             return
 
+        # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -164,8 +223,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.
 
-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
+        the final output batch have the same state as the input.
 
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +232,42 @@
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
-        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )
+
             # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
 
+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)
             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
-
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(
+                    ForwardBatch.init_new(
+                        model_worker_batch, self.target_worker.model_runner
+                    )
+                )
+            )
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
        else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.
-
-
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
             return logits_output, next_token_ids, bid, 0
 
     def forward_target_extend(
@@ -226,6 +297,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
 
+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
@@ -275,9 +353,7 @@
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
-
         return ret, out_cache_loc
 
     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +383,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])
 
-            #
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break
 
@@ -322,7 +398,7 @@ class EAGLEWorker(TpModelWorker):
         spec_info.hidden_states = hidden_states
 
         # Run forward
-        logits_output = self.
+        logits_output = self.draft_model_runner.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
         self._detect_nan_if_needed(logits_output)
@@ -351,11 +427,10 @@
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
+            res.accepeted_indices
         ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +482,7 @@ class EAGLEWorker(TpModelWorker):
             batch_next_token_ids,
         ]
 
-        # Add output logprobs to the request
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +531,38 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob
+
+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+
+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
+
         self._detect_nan_if_needed(logits_output)
-        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup
 
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
@@ -489,7 +575,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")
 
 
@@ -500,5 +586,5 @@ def load_token_map(token_map_path: str) -> List[int]:
         ignore_patterns=["*.bin", "*.safetensors"],
     )
     token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    hot_token_id = torch.load(token_map_path)
+    hot_token_id = torch.load(token_map_path, weights_only=True)
     return torch.tensor(hot_token_id, dtype=torch.int32)
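Two patterns in the eagle_worker.py hunks are worth a closer look. First, `draft_tp_context`: draft-model work runs under the draft model's own tensor-parallel group, with DP attention disabled, by temporarily patching the global parallel state. Below is a minimal sketch of that save/patch/restore shape, with a plain module global standing in for sglang's group state (all names in the sketch are illustrative, not sglang APIs):

from contextlib import contextmanager

_CURRENT_TP_GROUP = "target-tp-group"  # stand-in for the process-global TP group

@contextmanager
def patch_tp_group(group):
    """Temporarily swap the global TP group; always restore on exit."""
    global _CURRENT_TP_GROUP
    saved = _CURRENT_TP_GROUP
    _CURRENT_TP_GROUP = group
    try:
        yield
    finally:
        _CURRENT_TP_GROUP = saved

with patch_tp_group("draft-tp-group"):
    assert _CURRENT_TP_GROUP == "draft-tp-group"  # draft forwards happen here
assert _CURRENT_TP_GROUP == "target-tp-group"     # target state is restored

Second, the non-EAGLE3 branch clones the shared lm_head and slices it down to the hot-token rows (`head.data = head.data[self.hot_token_id]`), so the draft model scores only a small vocabulary. A toy-sized sketch of that slicing, and of mapping a hot-logit argmax back to a full-vocabulary id (shapes and ids here are made up):

import torch

vocab_size, hidden = 32000, 64
head = torch.randn(vocab_size, hidden)          # full lm_head weight (toy size)
hot_token_id = torch.tensor([0, 5, 42, 31999])  # token ids the draft may emit

hot_head = head.clone()                         # don't mutate the target's head
hot_head.data = hot_head.data[hot_token_id]     # now (4, hidden)

hidden_state = torch.randn(hidden)
draft_logits = hot_head @ hidden_state          # 4 logits instead of 32000
print(int(hot_token_id[draft_logits.argmax()])) # map back to the full vocab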
--- a/sglang/srt/speculative/spec_info.py
+++ b/sglang/srt/speculative/spec_info.py
@@ -4,17 +4,22 @@ from enum import IntEnum, auto
 class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
+    EAGLE3 = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
 
     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
+        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
+
+    def is_eagle3(self):
+        return self == SpeculativeAlgorithm.EAGLE3
 
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
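The net effect of this hunk: EAGLE3 is a new enum member that still satisfies `is_eagle()`, so shared EAGLE code paths pick it up automatically, while `is_eagle3()` gates the EAGLE3-only branches (its own lm_head, built-in hot-token map). A standalone mirror of the resulting class, for illustration only:

from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    EAGLE3 = auto()

    def is_eagle(self):
        # EAGLE3 counts as an EAGLE variant
        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3

    def is_eagle3(self):
        return self == SpeculativeAlgorithm.EAGLE3

assert SpeculativeAlgorithm.EAGLE3.is_eagle()       # shared EAGLE paths apply
assert SpeculativeAlgorithm.EAGLE3.is_eagle3()      # EAGLE3-only branches too
assert not SpeculativeAlgorithm.EAGLE.is_eagle3()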
--- a/sglang/srt/torch_memory_saver_adapter.py
+++ b/sglang/srt/torch_memory_saver_adapter.py
@@ -1,3 +1,4 @@
+import logging
 from abc import ABC
 from contextlib import contextmanager
 
@@ -8,6 +9,8 @@ try:
 except ImportError:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class TorchMemorySaverAdapter(ABC):
     @staticmethod
@@ -16,6 +19,13 @@ class TorchMemorySaverAdapter(ABC):
             _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
         )
 
+    def check_validity(self, caller_name):
+        if not self.enabled:
+            logger.warning(
+                f"`{caller_name}` will not save memory because torch_memory_saver is not enabled. "
+                f"Potential causes: `enable_memory_saver` is false, or torch_memory_saver has installation issues."
+            )
+
     def configure_subprocess(self):
         raise NotImplementedError
 
@@ -28,6 +38,10 @@ class TorchMemorySaverAdapter(ABC):
     def resume(self):
         raise NotImplementedError
 
+    @property
+    def enabled(self):
+        raise NotImplementedError
+
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def configure_subprocess(self):
@@ -42,6 +56,10 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def resume(self):
         return _primary_memory_saver.resume()
 
+    @property
+    def enabled(self):
+        return _primary_memory_saver.enabled
+
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     @contextmanager
@@ -57,3 +75,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
 
     def resume(self):
         pass
+
+    @property
+    def enabled(self):
+        return False