sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +4 -1
- sglang/srt/layers/rotary_embedding.py +6 -1
- sglang/srt/layers/sampler.py +28 -8
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +61 -35
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -65
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,17 +1,19 @@
 import logging
-from typing import
+from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings,
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
 
-if
-    from
+if is_cuda_available():
+    from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
@@ -21,11 +23,17 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -109,8 +117,6 @@ class Sampler(nn.Module):
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +130,21 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
-
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=self.tp_sync_group,
+            )
+
+        return batch_next_token_ids.to(torch.int32)
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
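The most notable behavioral change in sampler.py is the optional cross-TP sync of sampled token IDs. The sketch below is a minimal, hedged illustration of that gating logic written against plain torch.distributed rather than SGLang internals; the env-var parsing stands in for get_bool_env_var, and the tp_group argument plays the role of self.tp_sync_group.

```python
import os

import torch
import torch.distributed as dist


def maybe_sync_sampled_ids(
    batch_next_token_ids: torch.Tensor,
    tp_group: dist.ProcessGroup,
    has_grammar: bool,
) -> torch.Tensor:
    # By default the sync is skipped to save one all-reduce; it is forced on
    # when SYNC_TOKEN_IDS_ACROSS_TP is set or a grammar is active, so every
    # tensor-parallel rank ends up with identical token IDs.
    sync = os.getenv("SYNC_TOKEN_IDS_ACROSS_TP", "false").lower() in ("1", "true")
    if sync or has_grammar:
        dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=tp_group)
    return batch_next_token_ids.to(torch.int32)
```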
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch
 
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
     return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a modelwith torchao quantization specified by torchao_config
 
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -201,6 +201,7 @@ class DetokenizerManager:
                     prompt_tokens=recv_obj.prompt_tokens,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
+                    spec_verify_ct=recv_obj.spec_verify_ct,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                     input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                     output_token_logprobs_val=recv_obj.output_token_logprobs_val,
sglang/srt/managers/io_struct.py
CHANGED
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
@@ -69,8 +69,10 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
-    # Custom logit processor
-
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -248,8 +250,9 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
-    # Custom logit processor
-    #
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[str] = None
 
 
@@ -351,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -379,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -533,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -247,12 +247,12 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
         self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
@@ -300,7 +300,7 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return
+        # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,8 +329,14 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
+
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -550,13 +556,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -750,13 +756,6 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +771,20 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-
-
-
-
-
-
-
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
@@ -1026,7 +1030,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1112,6 +1116,8 @@ class ScheduleBatch:
         self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
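The cached-token accounting that used to be derived from extend_logprob_start_len is now driven by the explicit already_computed counter added to Req. A minimal sketch of that bookkeeping, using a stand-in class and made-up chunk sizes:

```python
# Toy mirror of the new counters on Req; not the real sglang class.
class FakeReq:
    def __init__(self) -> None:
        self.cached_tokens = 0
        self.already_computed = 0


def account_chunk(req: FakeReq, pre_len: int, seq_len: int) -> None:
    # pre_len: tokens already present in the KV cache when this chunk starts.
    # seq_len: prefix plus the tokens computed by this chunk.
    req.cached_tokens += pre_len - req.already_computed
    req.already_computed = seq_len


req = FakeReq()
# Chunk 1: 256 tokens hit the radix cache (shared with an earlier request);
# this chunk extends the sequence to 512 tokens.
account_chunk(req, pre_len=256, seq_len=512)
# Chunk 2: the prefix now includes the request's own first chunk, but those
# tokens were already counted via already_computed, so cached_tokens is unchanged.
account_chunk(req, pre_len=512, seq_len=1024)
print(req.cached_tokens)  # 256: only the genuine prefix-cache hit is counted
```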
sglang/srt/managers/scheduler.py
CHANGED
@@ -281,6 +281,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
                 },
             )
 
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -480,7 +486,7 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        result_queue = deque()
+        self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
@@ -491,7 +497,7 @@ class Scheduler:
 
             if batch:
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
                     # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -505,7 +511,7 @@ class Scheduler:
 
             if self.last_batch:
                 # Process the results of the last batch
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
@@ -636,7 +642,7 @@ class Scheduler:
                 self.waiting_queue.append(req)
                 return
 
-        # Handle
+        # Handle multimodal inputs
         if recv_req.image_inputs is not None:
             image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -660,24 +666,23 @@ class Scheduler:
                 self.waiting_queue.append(req)
                 return
 
-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return
 
+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
@@ -725,15 +730,26 @@ class Scheduler:
         req.tokenizer = self.tokenizer
 
         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)
 
-    def log_prefill_stats(
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: ScheduleBatch,
+        has_being_chunked: bool,
+    ):
         self.tree_cache_metrics["total"] += (
             adder.log_input_tokens + adder.log_hit_tokens
         ) / 10**9
@@ -1023,7 +1039,7 @@ class Scheduler:
             )
 
         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -1044,26 +1060,23 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
-            if
-
-
-
-
-                )
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    model_worker_batch,
-                    num_accepted_tokens,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                self.spec_num_total_accepted_tokens += (
-                    num_accepted_tokens + batch.batch_size()
-                )
-                self.spec_num_total_forward_ct += batch.batch_size()
-                self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
 
             ret = GenerationBatchResult(
@@ -1072,7 +1085,6 @@ class Scheduler:
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
@@ -1371,6 +1383,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1437,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)
 
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1467,7 @@ class Scheduler:
                     prompt_tokens,
                    completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
@@ -1564,6 +1581,15 @@ class Scheduler:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
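The scheduler now maintains speculative-decoding counters (spec_num_total_accepted_tokens, spec_num_total_forward_ct) and resets them on cache flush. A small sketch, with invented numbers, of how those counters relate to the average acceptance length:

```python
class SpecDecodeStats:
    """Toy mirror of the counters the Scheduler resets in flush_cache()."""

    def __init__(self) -> None:
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0

    def record_verify_pass(self, batch_size: int, num_accepted_tokens: int) -> None:
        # Same accounting as the speculative branch of run_batch shown above.
        self.spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
        self.spec_num_total_forward_ct += batch_size

    def avg_accept_length(self) -> float:
        return self.spec_num_total_accepted_tokens / max(self.spec_num_total_forward_ct, 1)


stats = SpecDecodeStats()
stats.record_verify_pass(batch_size=4, num_accepted_tokens=9)
stats.record_verify_pass(batch_size=4, num_accepted_tokens=5)
print(stats.avg_accept_length())  # 2.75 accepted tokens per verification pass
```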
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -785,6 +785,9 @@ class TokenizerManager:
                     i,
                 )
 
+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
+
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
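When a speculative algorithm is configured, each request's meta_info now carries spec_verify_ct. A hedged sketch of reading it from a locally running server; the URL, the launch configuration (e.g. an EAGLE speculative setup), and the exact response shape are assumptions rather than something shown in this diff:

```python
import requests

# Assumes an sglang server launched with a speculative-decoding algorithm
# is listening on localhost:30000.
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 64}},
).json()

meta = resp["meta_info"]
if meta.get("spec_verify_ct", 0) > 0:
    # Average accepted tokens per verification forward pass for this request.
    print(meta["completion_tokens"] / meta["spec_verify_ct"])
```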