sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -21,19 +21,22 @@ from sglang.srt.managers.schedule_batch import (
|
|
21
21
|
get_last_loc,
|
22
22
|
global_server_args_dict,
|
23
23
|
)
|
24
|
-
from sglang.srt.mem_cache.
|
24
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
25
25
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
26
|
-
from sglang.srt.utils import
|
26
|
+
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)
|
27
29
|
|
28
30
|
if is_cuda():
|
29
31
|
from sgl_kernel import (
|
32
|
+
fast_topk,
|
30
33
|
top_k_renorm_prob,
|
31
34
|
top_p_renorm_prob,
|
32
35
|
tree_speculative_sampling_target_only,
|
33
36
|
verify_tree_greedy,
|
34
37
|
)
|
35
38
|
elif is_hip():
|
36
|
-
from sgl_kernel import verify_tree_greedy
|
39
|
+
from sgl_kernel import fast_topk, verify_tree_greedy
|
37
40
|
|
38
41
|
|
39
42
|
logger = logging.getLogger(__name__)
|
@@ -67,9 +70,9 @@ class EagleDraftInput:
|
|
67
70
|
kv_indptr: torch.Tensor = None
|
68
71
|
kv_indices: torch.Tensor = None
|
69
72
|
|
70
|
-
all_padding_lens: Optional[torch.Tensor] = None
|
71
|
-
|
72
73
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
74
|
+
if batch.forward_mode.is_idle():
|
75
|
+
return
|
73
76
|
# Prefill only generate 1 token.
|
74
77
|
assert len(self.verified_id) == len(batch.seq_lens)
|
75
78
|
|
@@ -81,6 +84,25 @@ class EagleDraftInput:
|
|
81
84
|
)
|
82
85
|
pt += extend_len
|
83
86
|
|
87
|
+
@classmethod
|
88
|
+
def create_idle_input(
|
89
|
+
cls,
|
90
|
+
device: torch.device,
|
91
|
+
hidden_size: int,
|
92
|
+
dtype: torch.dtype,
|
93
|
+
topk: int,
|
94
|
+
capture_hidden_mode: CaptureHiddenMode,
|
95
|
+
):
|
96
|
+
return cls(
|
97
|
+
verified_id=None,
|
98
|
+
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
|
99
|
+
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
|
100
|
+
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
|
101
|
+
capture_hidden_mode=capture_hidden_mode,
|
102
|
+
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
|
103
|
+
accept_length_cpu=[],
|
104
|
+
)
|
105
|
+
|
84
106
|
def prepare_extend_after_decode(
|
85
107
|
self,
|
86
108
|
batch: ScheduleBatch,
|
@@ -93,6 +115,7 @@ class EagleDraftInput:
|
|
93
115
|
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
94
116
|
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
95
117
|
batch.return_logprob = False
|
118
|
+
batch.return_hidden_states = False
|
96
119
|
|
97
120
|
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
98
121
|
self.accept_length.add_(1)
|
@@ -116,13 +139,14 @@ class EagleDraftInput:
|
|
116
139
|
req_to_token: torch.Tensor,
|
117
140
|
):
|
118
141
|
bs = self.accept_length.numel()
|
119
|
-
|
120
142
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
121
143
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
122
|
-
|
123
144
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
124
145
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
125
146
|
|
147
|
+
if paged_kernel_lens_sum is None:
|
148
|
+
paged_kernel_lens_sum = cum_kv_seq_len[-1]
|
149
|
+
|
126
150
|
kv_indices = torch.empty(
|
127
151
|
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
128
152
|
)
|
@@ -136,7 +160,6 @@ class EagleDraftInput:
|
|
136
160
|
kv_indices,
|
137
161
|
req_to_token.size(1),
|
138
162
|
)
|
139
|
-
|
140
163
|
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
141
164
|
|
142
165
|
def filter_batch(self, new_indices: torch.Tensor):
|
@@ -193,7 +216,35 @@ class EagleVerifyInput:
|
|
193
216
|
seq_lens_cpu: torch.Tensor
|
194
217
|
grammar: BaseGrammarObject = None
|
195
218
|
|
219
|
+
@classmethod
|
220
|
+
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
|
221
|
+
return cls(
|
222
|
+
draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
|
223
|
+
custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
|
224
|
+
positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
|
225
|
+
retrive_index=torch.full(
|
226
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
|
227
|
+
),
|
228
|
+
retrive_next_token=torch.full(
|
229
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
|
230
|
+
),
|
231
|
+
retrive_next_sibling=torch.full(
|
232
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
|
233
|
+
),
|
234
|
+
retrive_cum_len=None,
|
235
|
+
topk=topk,
|
236
|
+
draft_token_num=num_verify_tokens,
|
237
|
+
spec_steps=spec_steps,
|
238
|
+
capture_hidden_mode=CaptureHiddenMode.FULL,
|
239
|
+
seq_lens_sum=0,
|
240
|
+
seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
|
241
|
+
)
|
242
|
+
|
196
243
|
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
244
|
+
|
245
|
+
if batch.forward_mode.is_idle():
|
246
|
+
return
|
247
|
+
|
197
248
|
batch.input_ids = self.draft_token
|
198
249
|
|
199
250
|
if page_size == 1:
|
@@ -265,9 +316,9 @@ class EagleVerifyInput:
|
|
265
316
|
self,
|
266
317
|
batch: ScheduleBatch,
|
267
318
|
logits_output: torch.Tensor,
|
268
|
-
token_to_kv_pool_allocator:
|
319
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
269
320
|
page_size: int,
|
270
|
-
vocab_mask: Optional[torch.Tensor] = None,
|
321
|
+
vocab_mask: Optional[torch.Tensor] = None, # For grammar
|
271
322
|
) -> torch.Tensor:
|
272
323
|
"""
|
273
324
|
Verify and find accepted tokens based on logits output and batch
|
@@ -279,6 +330,26 @@ class EagleVerifyInput:
|
|
279
330
|
tokens. I.e., logits_output.next_token_logits only contains
|
280
331
|
accepted token logits.
|
281
332
|
"""
|
333
|
+
if batch.forward_mode.is_idle():
|
334
|
+
return EagleVerifyOutput(
|
335
|
+
draft_input=EagleDraftInput.create_idle_input(
|
336
|
+
device=batch.device,
|
337
|
+
hidden_size=batch.model_config.hidden_size,
|
338
|
+
dtype=batch.model_config.dtype,
|
339
|
+
topk=self.topk,
|
340
|
+
capture_hidden_mode=CaptureHiddenMode.LAST,
|
341
|
+
),
|
342
|
+
logits_output=logits_output,
|
343
|
+
verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
|
344
|
+
accept_length_per_req_cpu=[],
|
345
|
+
accepted_indices=torch.full(
|
346
|
+
(0, self.spec_steps + 1),
|
347
|
+
-1,
|
348
|
+
dtype=torch.int32,
|
349
|
+
device=batch.device,
|
350
|
+
),
|
351
|
+
)
|
352
|
+
|
282
353
|
bs = self.retrive_index.shape[0]
|
283
354
|
candidates = self.draft_token.reshape(bs, self.draft_token_num)
|
284
355
|
sampling_info = batch.sampling_info
|
@@ -291,6 +362,14 @@ class EagleVerifyInput:
|
|
291
362
|
)
|
292
363
|
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
293
364
|
|
365
|
+
# Apply the custom logit processors if registered in the sampling info.
|
366
|
+
if sampling_info.has_custom_logit_processor:
|
367
|
+
apply_custom_logit_processor(
|
368
|
+
logits_output.next_token_logits,
|
369
|
+
sampling_info,
|
370
|
+
num_tokens_in_batch=self.draft_token_num,
|
371
|
+
)
|
372
|
+
|
294
373
|
# Apply penalty
|
295
374
|
if sampling_info.penalizer_orchestrator.is_required:
|
296
375
|
# This is a relaxed version of penalties for speculative decoding.
|
@@ -320,11 +399,11 @@ class EagleVerifyInput:
|
|
320
399
|
predicts=predict, # mutable
|
321
400
|
accept_index=accept_index, # mutable
|
322
401
|
accept_token_num=accept_length, # mutable
|
323
|
-
candidates=candidates
|
324
|
-
retrive_index=self.retrive_index
|
325
|
-
retrive_next_token=self.retrive_next_token
|
326
|
-
retrive_next_sibling=self.retrive_next_sibling
|
327
|
-
target_predict=target_predict
|
402
|
+
candidates=candidates,
|
403
|
+
retrive_index=self.retrive_index,
|
404
|
+
retrive_next_token=self.retrive_next_token,
|
405
|
+
retrive_next_sibling=self.retrive_next_sibling,
|
406
|
+
target_predict=target_predict,
|
328
407
|
)
|
329
408
|
else:
|
330
409
|
# apply temperature and get target probs
|
@@ -352,16 +431,23 @@ class EagleVerifyInput:
|
|
352
431
|
draft_probs = torch.zeros(
|
353
432
|
target_probs.shape, dtype=torch.float32, device="cuda"
|
354
433
|
)
|
434
|
+
|
435
|
+
# coins for rejection sampling
|
355
436
|
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
|
437
|
+
# coins for final sampling
|
438
|
+
coins_for_final_sampling = torch.rand(
|
439
|
+
(bs,), dtype=torch.float32, device="cuda"
|
440
|
+
)
|
356
441
|
tree_speculative_sampling_target_only(
|
357
442
|
predicts=predict, # mutable
|
358
443
|
accept_index=accept_index, # mutable
|
359
444
|
accept_token_num=accept_length, # mutable
|
360
|
-
candidates=candidates
|
361
|
-
retrive_index=self.retrive_index
|
362
|
-
retrive_next_token=self.retrive_next_token
|
363
|
-
retrive_next_sibling=self.retrive_next_sibling
|
445
|
+
candidates=candidates,
|
446
|
+
retrive_index=self.retrive_index,
|
447
|
+
retrive_next_token=self.retrive_next_token,
|
448
|
+
retrive_next_sibling=self.retrive_next_sibling,
|
364
449
|
uniform_samples=coins,
|
450
|
+
uniform_samples_for_final_sampling=coins_for_final_sampling,
|
365
451
|
target_probs=target_probs,
|
366
452
|
draft_probs=draft_probs,
|
367
453
|
threshold_single=global_server_args_dict[
|
@@ -384,8 +470,8 @@ class EagleVerifyInput:
|
|
384
470
|
spec_steps=self.spec_steps,
|
385
471
|
)
|
386
472
|
|
387
|
-
new_accept_index = []
|
388
473
|
unfinished_index = []
|
474
|
+
unfinished_accept_index = []
|
389
475
|
accept_index_cpu = accept_index.tolist()
|
390
476
|
predict_cpu = predict.tolist()
|
391
477
|
has_finished = False
|
@@ -393,12 +479,10 @@ class EagleVerifyInput:
|
|
393
479
|
# Iterate every accepted token and check if req has finished after append the token
|
394
480
|
# should be checked BEFORE free kv cache slots
|
395
481
|
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
396
|
-
new_accept_index_ = []
|
397
482
|
for j, idx in enumerate(accept_index_row):
|
398
483
|
if idx == -1:
|
399
484
|
break
|
400
485
|
id = predict_cpu[idx]
|
401
|
-
# if not found_finished:
|
402
486
|
req.output_ids.append(id)
|
403
487
|
req.check_finished()
|
404
488
|
if req.finished():
|
@@ -407,8 +491,6 @@ class EagleVerifyInput:
|
|
407
491
|
accept_index[i, j + 1 :] = -1
|
408
492
|
break
|
409
493
|
else:
|
410
|
-
new_accept_index_.append(idx)
|
411
|
-
# update grammar state
|
412
494
|
if req.grammar is not None:
|
413
495
|
try:
|
414
496
|
req.grammar.accept_token(id)
|
@@ -418,50 +500,104 @@ class EagleVerifyInput:
|
|
418
500
|
)
|
419
501
|
raise e
|
420
502
|
if not req.finished():
|
421
|
-
new_accept_index.extend(new_accept_index_)
|
422
503
|
unfinished_index.append(i)
|
504
|
+
if idx == -1:
|
505
|
+
unfinished_accept_index.append(accept_index[i, :j])
|
506
|
+
else:
|
507
|
+
unfinished_accept_index.append(accept_index[i])
|
423
508
|
req.spec_verify_ct += 1
|
424
509
|
|
425
510
|
if has_finished:
|
426
511
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
427
512
|
|
428
513
|
# Free the KV cache for unaccepted tokens
|
514
|
+
# TODO: fuse them
|
429
515
|
accept_index = accept_index[accept_index != -1]
|
430
516
|
verified_id = predict[accept_index]
|
431
517
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
432
518
|
evict_mask[accept_index] = False
|
433
519
|
|
434
|
-
if page_size
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
520
|
+
if page_size == 1:
|
521
|
+
# TODO: boolean array index leads to a device sync. Remove it.
|
522
|
+
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
523
|
+
else:
|
524
|
+
if self.topk == 1:
|
525
|
+
# Only evict full empty page. Do not evict partial empty page
|
526
|
+
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
527
|
+
batch.seq_lens,
|
528
|
+
evict_mask,
|
529
|
+
page_size,
|
530
|
+
self.draft_token_num,
|
531
|
+
next_power_of_2(self.draft_token_num),
|
532
|
+
)
|
533
|
+
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
534
|
+
else:
|
535
|
+
# Shift the accepted tokens to the beginning.
|
536
|
+
# Only evict the last part
|
537
|
+
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
|
538
|
+
batch.seq_lens,
|
539
|
+
batch.out_cache_loc,
|
540
|
+
accept_index,
|
541
|
+
accept_length,
|
542
|
+
self.draft_token_num,
|
543
|
+
page_size,
|
544
|
+
)
|
545
|
+
to_free_slots = torch.empty(
|
546
|
+
(to_free_num_slots.sum().item(),),
|
547
|
+
dtype=torch.int64,
|
548
|
+
device=to_free_num_slots.device,
|
549
|
+
)
|
442
550
|
|
443
|
-
|
551
|
+
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
|
552
|
+
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
|
553
|
+
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
|
554
|
+
# to_free_slots: [ 2, 5, 7 8]
|
555
|
+
# to_free_slots also needs to be page-aligned without the first partial page
|
556
|
+
#
|
557
|
+
# split each row of out_cache_loc into two parts.
|
558
|
+
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
|
559
|
+
# 2. the second part goes to to_free_slots.
|
560
|
+
get_target_cache_loc[(bs,)](
|
561
|
+
tgt_cache_loc,
|
562
|
+
to_free_slots,
|
563
|
+
accept_length,
|
564
|
+
to_free_num_slots,
|
565
|
+
batch.out_cache_loc,
|
566
|
+
self.draft_token_num,
|
567
|
+
next_power_of_2(self.draft_token_num),
|
568
|
+
next_power_of_2(bs),
|
569
|
+
)
|
570
|
+
|
571
|
+
# Free the kv cache
|
572
|
+
token_to_kv_pool_allocator.free(to_free_slots)
|
573
|
+
|
574
|
+
# Copy the kv cache
|
575
|
+
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
576
|
+
tgt_cache_loc, src_cache_loc
|
577
|
+
)
|
444
578
|
|
445
579
|
# Construct EagleVerifyOutput
|
446
580
|
if not has_finished:
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
581
|
+
if page_size == 1 or self.topk == 1:
|
582
|
+
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
583
|
+
assign_req_to_token_pool[(bs,)](
|
584
|
+
batch.req_pool_indices,
|
585
|
+
batch.req_to_token_pool.req_to_token,
|
586
|
+
batch.seq_lens,
|
587
|
+
batch.seq_lens + accept_length + 1,
|
588
|
+
batch.out_cache_loc,
|
589
|
+
batch.req_to_token_pool.req_to_token.shape[1],
|
590
|
+
next_power_of_2(bs),
|
591
|
+
)
|
592
|
+
else:
|
593
|
+
batch.out_cache_loc = tgt_cache_loc
|
457
594
|
batch.seq_lens.add_(accept_length + 1)
|
458
|
-
accept_length_cpu = accept_length.tolist()
|
459
595
|
|
460
596
|
draft_input = EagleDraftInput()
|
461
597
|
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
|
462
598
|
draft_input.verified_id = verified_id
|
463
599
|
draft_input.accept_length = accept_length
|
464
|
-
draft_input.accept_length_cpu =
|
600
|
+
draft_input.accept_length_cpu = accept_length.tolist()
|
465
601
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
466
602
|
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
467
603
|
|
@@ -469,47 +605,66 @@ class EagleVerifyInput:
|
|
469
605
|
draft_input=draft_input,
|
470
606
|
logits_output=logits_output,
|
471
607
|
verified_id=verified_id,
|
472
|
-
accept_length_per_req_cpu=accept_length_cpu,
|
608
|
+
accept_length_per_req_cpu=draft_input.accept_length_cpu,
|
473
609
|
accepted_indices=accept_index,
|
474
610
|
)
|
475
611
|
else:
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
612
|
+
if page_size == 1 or self.topk == 1:
|
613
|
+
assign_req_to_token_pool[(bs,)](
|
614
|
+
batch.req_pool_indices,
|
615
|
+
batch.req_to_token_pool.req_to_token,
|
616
|
+
batch.seq_lens,
|
617
|
+
batch.seq_lens + accept_length + 1,
|
618
|
+
batch.out_cache_loc[accept_index],
|
619
|
+
batch.req_to_token_pool.req_to_token.shape[1],
|
620
|
+
next_power_of_2(bs),
|
621
|
+
)
|
622
|
+
batch.seq_lens.add_(accept_length + 1)
|
487
623
|
|
624
|
+
accept_length_cpu = accept_length.tolist()
|
488
625
|
draft_input = EagleDraftInput()
|
489
|
-
if len(
|
490
|
-
|
491
|
-
unfinished_index_device = torch.tensor(
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
draft_input.verified_id = predict[new_accept_index]
|
496
|
-
draft_input.accept_length_cpu = [
|
626
|
+
if len(unfinished_accept_index) > 0:
|
627
|
+
unfinished_accept_index = torch.cat(unfinished_accept_index)
|
628
|
+
unfinished_index_device = torch.tensor(
|
629
|
+
unfinished_index, dtype=torch.int64, device=predict.device
|
630
|
+
)
|
631
|
+
draft_input_accept_length_cpu = [
|
497
632
|
accept_length_cpu[i] for i in unfinished_index
|
498
633
|
]
|
499
|
-
|
500
|
-
|
501
|
-
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
502
|
-
unfinished_index_device
|
503
|
-
]
|
504
|
-
draft_input.req_pool_indices_for_draft_extend = (
|
505
|
-
batch.req_pool_indices[unfinished_index_device]
|
506
|
-
)
|
634
|
+
if page_size == 1 or self.topk == 1:
|
635
|
+
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
|
507
636
|
else:
|
508
|
-
|
509
|
-
|
510
|
-
|
637
|
+
batch.out_cache_loc = torch.empty(
|
638
|
+
len(unfinished_index) + sum(draft_input_accept_length_cpu),
|
639
|
+
dtype=torch.int64,
|
640
|
+
device=predict.device,
|
641
|
+
)
|
642
|
+
accept_length_filter = create_accept_length_filter(
|
643
|
+
accept_length,
|
644
|
+
unfinished_index_device,
|
645
|
+
batch.seq_lens,
|
646
|
+
)
|
647
|
+
filter_finished_cache_loc_kernel[(bs,)](
|
648
|
+
batch.out_cache_loc,
|
649
|
+
tgt_cache_loc,
|
650
|
+
accept_length,
|
651
|
+
accept_length_filter,
|
652
|
+
next_power_of_2(bs),
|
653
|
+
next_power_of_2(self.draft_token_num),
|
511
654
|
)
|
512
|
-
|
655
|
+
|
656
|
+
draft_input.hidden_states = batch.spec_info.hidden_states[
|
657
|
+
unfinished_accept_index
|
658
|
+
]
|
659
|
+
draft_input.verified_id = predict[unfinished_accept_index]
|
660
|
+
draft_input.accept_length_cpu = draft_input_accept_length_cpu
|
661
|
+
draft_input.accept_length = accept_length[unfinished_index_device]
|
662
|
+
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
663
|
+
unfinished_index_device
|
664
|
+
]
|
665
|
+
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
666
|
+
unfinished_index_device
|
667
|
+
]
|
513
668
|
|
514
669
|
return EagleVerifyOutput(
|
515
670
|
draft_input=draft_input,
|
@@ -586,36 +741,75 @@ def assign_draft_cache_locs(
|
|
586
741
|
req_pool_indices,
|
587
742
|
req_to_token,
|
588
743
|
seq_lens,
|
744
|
+
extend_lens,
|
745
|
+
num_new_pages_per_topk,
|
589
746
|
out_cache_loc,
|
590
747
|
pool_len: tl.constexpr,
|
591
748
|
topk: tl.constexpr,
|
592
749
|
speculative_num_steps: tl.constexpr,
|
593
750
|
page_size: tl.constexpr,
|
751
|
+
bs_upper: tl.constexpr,
|
752
|
+
iter_upper: tl.constexpr,
|
594
753
|
):
|
595
|
-
BLOCK_SIZE: tl.constexpr =
|
754
|
+
BLOCK_SIZE: tl.constexpr = 128
|
596
755
|
pid = tl.program_id(axis=0)
|
597
|
-
kv_start = tl.load(seq_lens + pid)
|
598
756
|
|
599
757
|
if page_size == 1 or topk == 1:
|
600
|
-
|
758
|
+
copy_len = topk * speculative_num_steps
|
601
759
|
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
602
760
|
else:
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
) // page_size
|
608
|
-
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
|
761
|
+
bs_offset = tl.arange(0, bs_upper)
|
762
|
+
copy_len = tl.load(extend_lens + pid)
|
763
|
+
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
|
764
|
+
out_cache_ptr = out_cache_loc + cum_copy_len
|
609
765
|
|
766
|
+
# Part 1: Copy from out_cache_loc to req_to_token
|
767
|
+
kv_start = tl.load(seq_lens + pid)
|
610
768
|
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
611
|
-
|
612
|
-
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
769
|
+
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
|
613
770
|
for i in range(num_loop):
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
771
|
+
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
772
|
+
mask = copy_offset < copy_len
|
773
|
+
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
|
774
|
+
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
|
775
|
+
|
776
|
+
if page_size == 1 or topk == 1:
|
777
|
+
return
|
778
|
+
|
779
|
+
# Part 2: Copy the indices for the last partial page
|
780
|
+
prefix_len = tl.load(seq_lens + pid)
|
781
|
+
last_page_len = prefix_len % page_size
|
782
|
+
offsets = tl.arange(0, page_size)
|
783
|
+
mask = offsets < last_page_len
|
784
|
+
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
|
785
|
+
prefix_base = token_pool + prefix_len - last_page_len
|
786
|
+
|
787
|
+
for topk_id in range(topk):
|
788
|
+
value = tl.load(prefix_base + offsets, mask=mask)
|
789
|
+
tl.store(
|
790
|
+
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
|
791
|
+
value,
|
792
|
+
mask=mask,
|
793
|
+
)
|
794
|
+
|
795
|
+
# Part 3: Remove the padding in out_cache_loc
|
796
|
+
iter_offest = tl.arange(0, iter_upper)
|
797
|
+
for topk_id in range(topk):
|
798
|
+
indices = tl.load(
|
799
|
+
prefix_base
|
800
|
+
+ topk_id * num_new_pages_per_topk_ * page_size
|
801
|
+
+ last_page_len
|
802
|
+
+ iter_offest,
|
803
|
+
mask=iter_offest < speculative_num_steps,
|
804
|
+
)
|
805
|
+
tl.store(
|
806
|
+
out_cache_loc
|
807
|
+
+ pid * topk * speculative_num_steps
|
808
|
+
+ topk_id * speculative_num_steps
|
809
|
+
+ iter_offest,
|
810
|
+
indices,
|
811
|
+
mask=iter_offest < speculative_num_steps,
|
812
|
+
)
|
619
813
|
|
620
814
|
|
621
815
|
@triton.jit
|
@@ -626,20 +820,23 @@ def generate_draft_decode_kv_indices(
|
|
626
820
|
kv_indices,
|
627
821
|
kv_indptr,
|
628
822
|
positions,
|
629
|
-
num_seqs: tl.constexpr,
|
630
|
-
topk: tl.constexpr,
|
631
823
|
pool_len: tl.constexpr,
|
632
824
|
kv_indices_stride: tl.constexpr,
|
633
825
|
kv_indptr_stride: tl.constexpr,
|
634
826
|
bs_upper: tl.constexpr,
|
635
827
|
iter_upper: tl.constexpr,
|
636
828
|
num_tokens_upper: tl.constexpr,
|
829
|
+
page_size: tl.constexpr,
|
637
830
|
):
|
638
831
|
BLOCK_SIZE: tl.constexpr = 128
|
639
832
|
iters = tl.program_id(axis=0)
|
640
833
|
bid = tl.program_id(axis=1)
|
641
834
|
topk_id = tl.program_id(axis=2)
|
642
835
|
|
836
|
+
num_steps = tl.num_programs(axis=0)
|
837
|
+
num_seqs = tl.num_programs(axis=1)
|
838
|
+
topk = tl.num_programs(axis=2)
|
839
|
+
|
643
840
|
kv_indices += kv_indices_stride * iters
|
644
841
|
kv_indptr += kv_indptr_stride * iters
|
645
842
|
iters += 1
|
@@ -649,6 +846,7 @@ def generate_draft_decode_kv_indices(
|
|
649
846
|
seq_len = tl.load(paged_kernel_lens + bid)
|
650
847
|
cum_seq_len = tl.sum(seq_lens)
|
651
848
|
|
849
|
+
# Update kv_indices
|
652
850
|
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
653
851
|
kv_ptr = kv_indices + kv_offset
|
654
852
|
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
@@ -662,10 +860,26 @@ def generate_draft_decode_kv_indices(
|
|
662
860
|
kv_offset += BLOCK_SIZE
|
663
861
|
|
664
862
|
extend_offset = tl.arange(0, iter_upper)
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
863
|
+
if page_size == 1 or topk == 1:
|
864
|
+
extend_data = tl.load(
|
865
|
+
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
|
866
|
+
mask=extend_offset < iters,
|
867
|
+
)
|
868
|
+
else:
|
869
|
+
prefix_len = seq_len
|
870
|
+
last_page_len = prefix_len % page_size
|
871
|
+
num_new_pages_per_topk = (
|
872
|
+
last_page_len + num_steps + page_size - 1
|
873
|
+
) // page_size
|
874
|
+
prefix_base = seq_len // page_size * page_size
|
875
|
+
start = (
|
876
|
+
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
|
877
|
+
)
|
878
|
+
extend_data = tl.load(
|
879
|
+
token_pool_ptr + start + extend_offset,
|
880
|
+
mask=extend_offset < iters,
|
881
|
+
)
|
882
|
+
|
669
883
|
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
670
884
|
|
671
885
|
# Update kv_indptr
|
@@ -704,6 +918,116 @@ def align_evict_mask_to_page_size(
|
|
704
918
|
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
705
919
|
|
706
920
|
|
921
|
+
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split one request's row of ``out_cache_loc`` into kept and freed slots.

    One program handles one request; ``bid`` is the program id on axis 0.
    Row ``bid`` of ``out_cache_loc`` is addressed as ``num_verify_tokens``
    consecutive entries starting at ``bid * num_verify_tokens``.

    * The first ``accept_length[bid] + 1`` entries of the row are packed into
      ``tgt_cache_loc`` starting at ``sum(accept_length[:bid]) + bid``.
    * The last ``to_free_num_slots[bid]`` entries of the row are packed into
      ``to_free_slots`` starting at ``sum(to_free_num_slots[:bid])``.

    Args:
        tgt_cache_loc: Output buffer receiving the leading entries of each row.
        to_free_slots: Output buffer receiving the trailing entries of each row.
        accept_length: Per-request count; ``accept_length[bid] + 1`` entries
            are copied for request ``bid``.
        to_free_num_slots: Per-request number of trailing entries to copy.
        out_cache_loc: Input buffer, read row by row.
        num_verify_tokens: Row stride of ``out_cache_loc``.
        num_verify_tokens_upper: Extent of the per-row offset range
            (presumably ``num_verify_tokens`` rounded up to a power of two;
            confirm against the launcher).
        bs_upper: Extent of the per-request offset range used for the prefix
            sums (presumably the batch size rounded up to a power of two;
            confirm against the launcher).
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # Write the first part of the row to tgt_cache_loc.
    # `other=0` keeps masked-out lanes from contributing to the prefix sum;
    # without it their contents are not specified by Triton.
    accept_len_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    # Each preceding request contributed accept_length + 1 entries, hence "+ bid".
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # Write the second part (the tail of the row) to to_free_slots.
    to_free_num_slots_all = tl.load(
        to_free_num_slots + bs_offset, mask=bs_offset < bid, other=0
    )
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
|
965
|
+
|
966
|
+
|
967
|
+
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Gather the accepted cache slots and count the slots each request frees.

    Args:
        seq_lens: Per-request sequence lengths.
        out_cache_loc: Cache locations, indexed by ``accept_index``.
        accept_index: Indices into ``out_cache_loc`` selecting the source slots.
        accept_length: Per-request accepted length; ``accept_length + 1``
            tokens are counted as kept.
        draft_token_num: Number of slots added on top of ``seq_lens`` for
            every request.
        page_size: Page granularity used to round the kept length up.

    Returns:
        A tuple ``(src_cache_loc, tgt_cache_loc, to_free_num_slots)`` where
        ``src_cache_loc`` is ``out_cache_loc[accept_index]``, ``tgt_cache_loc``
        is an uninitialized tensor shaped like ``src_cache_loc``, and
        ``to_free_num_slots`` is the per-request difference between the
        extended length and the kept length.
    """
    src_cache_loc = out_cache_loc[accept_index]
    # Allocated only; the caller is expected to fill it in.
    tgt_cache_loc = torch.empty_like(src_cache_loc)

    total_len = seq_lens + draft_token_num
    accepted_len = seq_lens + accept_length + 1
    # Round up to a whole page, but never past what was actually extended.
    page_aligned_len = (accepted_len + page_size - 1) // page_size * page_size
    retained_len = torch.minimum(page_aligned_len, total_len)

    return src_cache_loc, tgt_cache_loc, total_len - retained_len
|
985
|
+
|
986
|
+
|
987
|
+
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact per-request segments of ``tgt_cache_loc`` into ``out_cache_loc``.

    One program handles one request; ``bid`` is the program id on axis 0.
    For request ``bid`` the kernel reads ``accept_length_filter[bid]`` entries
    of ``tgt_cache_loc`` starting at ``sum(accept_length[:bid]) + bid`` and
    writes them to ``out_cache_loc`` starting at
    ``sum(accept_length_filter[:bid])``. A request whose filter entry is zero
    copies nothing.

    Args:
        out_cache_loc: Output buffer receiving the compacted entries.
        tgt_cache_loc: Input buffer laid out with ``accept_length + 1``
            entries per request.
        accept_length: Per-request lengths defining the input layout.
        accept_length_filter: Per-request number of entries to copy, which
            also defines the output layout.
        bs_upper: Extent of the per-request offset range used for the prefix
            sums (presumably the batch size rounded up to a power of two;
            confirm against the launcher).
        num_verify_tokens_upper: Extent of the per-request copy range
            (presumably an upper bound on the entries per request; confirm
            against the launcher).
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # `other=0` keeps masked-out lanes from contributing to the prefix sums;
    # without it their contents are not specified by Triton.
    accept_length_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    # Each preceding request occupies accept_length + 1 entries, hence "+ bid".
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid, other=0
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
|
1015
|
+
|
1016
|
+
|
1017
|
+
@torch.compile(dynamic=True)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Build the per-request copy lengths and advance ``seq_lens`` in place.

    Args:
        accept_length: Per-request accepted length.
        unfinished_index_device: Indices of the requests whose entry in the
            returned filter should be non-zero.
        seq_lens: Per-request sequence lengths; mutated in place by adding
            ``accept_length + 1`` for every request.

    Returns:
        A tensor shaped like ``accept_length`` holding ``accept_length + 1``
        at the positions in ``unfinished_index_device`` and zero elsewhere.
    """
    accepted_token_count = accept_length + 1

    accept_length_filter = torch.zeros_like(accept_length)
    accept_length_filter[unfinished_index_device] = accepted_token_count[
        unfinished_index_device
    ]

    # Every request advances, including those not listed as unfinished.
    seq_lens.add_(accepted_token_count)
    return accept_length_filter
|
1029
|
+
|
1030
|
+
|
707
1031
|
@torch.compile(dynamic=True)
|
708
1032
|
def select_top_k_tokens(
|
709
1033
|
i: int,
|
@@ -739,10 +1063,11 @@ def select_top_k_tokens(
|
|
739
1063
|
topk_index = topk_index.reshape(-1, topk**2)
|
740
1064
|
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
|
741
1065
|
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
1066
|
+
if hidden_states.shape[0] > 0:
|
1067
|
+
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
1068
|
+
0, hidden_states.shape[0], step=topk, device="cuda"
|
1069
|
+
).repeat_interleave(topk)
|
1070
|
+
hidden_states = hidden_states[selected_input_index, :]
|
746
1071
|
|
747
1072
|
tree_info = (
|
748
1073
|
expand_scores, # shape: (b, topk, topk)
|
@@ -762,15 +1087,35 @@ def _generate_simulated_accept_index(
|
|
762
1087
|
spec_steps,
|
763
1088
|
):
|
764
1089
|
simulate_acc_len_float = float(simulate_acc_len)
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
1090
|
+
if SIMULATE_ACC_METHOD == "multinomial":
|
1091
|
+
simulated_values = torch.normal(
|
1092
|
+
mean=simulate_acc_len_float,
|
1093
|
+
std=1.0,
|
1094
|
+
size=(1,),
|
1095
|
+
device="cpu",
|
1096
|
+
)
|
1097
|
+
# clamp simulated values to be between 1 and spec_steps + 1
|
1098
|
+
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
|
1099
|
+
simulate_acc_len = int(simulated_values.round().item())
|
1100
|
+
elif SIMULATE_ACC_METHOD == "match-expected":
|
1101
|
+
# multinomial sampling does not match the expected length
|
1102
|
+
# we keep it for the sake of compatibility of existing tests
|
1103
|
+
# but it's better to use "match-expected" for the cases that need to
|
1104
|
+
# match the expected length, One caveat is that this will only sample
|
1105
|
+
# either round down or round up of the expected length
|
1106
|
+
simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
|
1107
|
+
lower = int(simulate_acc_len_float // 1)
|
1108
|
+
upper = lower + 1 if lower < spec_steps + 1 else lower
|
1109
|
+
if lower == upper:
|
1110
|
+
simulate_acc_len = lower
|
1111
|
+
else:
|
1112
|
+
weight_upper = simulate_acc_len_float - lower
|
1113
|
+
weight_lower = 1.0 - weight_upper
|
1114
|
+
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
|
1115
|
+
sampled_index = torch.multinomial(probs, num_samples=1)
|
1116
|
+
simulate_acc_len = lower if sampled_index == 0 else upper
|
1117
|
+
else:
|
1118
|
+
raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
|
774
1119
|
|
775
1120
|
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
|
776
1121
|
sim_accept_index = torch.full(
|
@@ -861,9 +1206,9 @@ def generate_token_bitmask(
|
|
861
1206
|
"""
|
862
1207
|
Generate the logit mask for structured output.
|
863
1208
|
Draft model's token can be either valid or invalid with respect to the grammar.
|
864
|
-
We need to perform DFS to
|
865
|
-
1. which tokens are accepted by the grammar
|
866
|
-
2. what is the corresponding logit mask.
|
1209
|
+
We need to perform DFS to
|
1210
|
+
1. figure out which tokens are accepted by the grammar.
|
1211
|
+
2. if so, what is the corresponding logit mask.
|
867
1212
|
"""
|
868
1213
|
|
869
1214
|
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
@@ -880,6 +1225,7 @@ def generate_token_bitmask(
|
|
880
1225
|
device="cpu",
|
881
1226
|
)
|
882
1227
|
grammar = req.grammar
|
1228
|
+
s = time.perf_counter()
|
883
1229
|
traverse_tree(
|
884
1230
|
retrieve_next_token_cpu[i],
|
885
1231
|
retrieve_next_sibling_cpu[i],
|
@@ -889,6 +1235,12 @@ def generate_token_bitmask(
|
|
889
1235
|
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
890
1236
|
],
|
891
1237
|
)
|
1238
|
+
tree_traverse_time = time.perf_counter() - s
|
1239
|
+
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
|
1240
|
+
logger.warning(
|
1241
|
+
f"Bit mask generation took {tree_traverse_time} seconds with "
|
1242
|
+
f"grammar: {req.grammar}"
|
1243
|
+
)
|
892
1244
|
|
893
1245
|
verify_input.grammar = grammar
|
894
1246
|
return allocate_token_bitmask
|