sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -100,6 +100,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +133,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +195,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -206,6 +212,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
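
The llama.py change above threads an optional projection bias through LlamaAttention so that checkpoints which keep QKV/output biases (llamafied Qwen2.5, InternLM) load correctly. A minimal sketch of the detection logic on stand-in config objects; the SimpleNamespace configs are illustrative, not real model configs:

from types import SimpleNamespace

def resolve_attention_bias(config) -> bool:
    # Same fallback chain as the new LlamaDecoderLayer code: HF configs may expose
    # either `attention_bias` (llamafied Qwen2.5) or `bias` (InternLM).
    return getattr(config, "attention_bias", False) or getattr(config, "bias", False)

print(resolve_attention_bias(SimpleNamespace(attention_bias=True)))  # True
print(resolve_attention_bias(SimpleNamespace(bias=True)))            # True
print(resolve_attention_bias(SimpleNamespace()))                     # False: keep bias-free projections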
sglang/srt/openai_api/adapter.py
CHANGED
@@ -696,14 +696,6 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):

 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)

@@ -1176,15 +1168,6 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        # For example, if 'ebnf' is given:
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

sglang/srt/openai_api/protocol.py
CHANGED
@@ -171,15 +171,15 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    regex: Optional[str] = None
     json_schema: Optional[str] = None
+    regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None


 class CompletionResponseChoice(BaseModel):
@@ -315,13 +315,13 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     min_tokens: int = 0
     regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None


 class FunctionResponse(BaseModel):
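
With ebnf and regex promoted to first-class fields on CompletionRequest and ChatCompletionRequest, the extra_body unpacking removed from adapter.py above is no longer needed: the OpenAI Python client merges extra_body keys into the top-level request JSON, so they reach the pydantic models directly. A hedged usage sketch; the server URL, model name, and grammar are placeholders:

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# extra_body keys are merged into the top-level JSON by the SDK,
# so "ebnf" lands on ChatCompletionRequest.ebnf on the server side.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Answer yes or no: is the sky blue?"}],
    extra_body={"ebnf": 'root ::= "yes" | "no"'},
)
print(response.choices[0].message.content)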
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -232,3 +232,25 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        # Apply logit_bias
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        # min-token, presence, frequency
+        if self.linear_penalties is not None:
+            logits.add_(self.linear_penalties)
+
+        # repetition
+        if self.scaling_penalties is not None:
+            logits[:] = torch.where(
+                logits > 0,
+                logits / self.scaling_penalties,
+                logits * self.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if self.vocab_mask is not None:
+            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
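
apply_logits_bias gathers all logit post-processing for a batch into one method. A self-contained sketch of the same arithmetic on dummy tensors (shapes, penalty values, and mask semantics are illustrative only):

import torch

batch, vocab = 2, 8
logits = torch.randn(batch, vocab)

# Additive terms: user logit_bias plus min-token/presence/frequency penalties.
logit_bias = torch.zeros(batch, vocab)
linear_penalties = torch.zeros(batch, vocab)
logits = logits + logit_bias + linear_penalties

# Repetition penalty: divide positive logits, multiply negative ones,
# so penalized tokens move toward lower probability in both cases.
scaling_penalties = torch.full((batch, vocab), 1.2)
logits = torch.where(logits > 0, logits / scaling_penalties, logits * scaling_penalties)

# Grammar/regex constraint: mask disallowed tokens before sampling.
vocab_mask = torch.zeros(batch, vocab, dtype=torch.bool)
vocab_mask[:, 4:] = True  # pretend only the first four tokens are currently legal
logits = logits.masked_fill(vocab_mask, float("-inf"))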
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -19,6 +19,14 @@ _SAMPLING_EPS = 1e-6


 class SamplingParams:
+    """
+    The sampling parameters.
+
+    See docs/references/sampling_params.md or
+    https://sgl-project.github.io/references/sampling_params.html
+    for the documentation.
+    """
+
     def __init__(
         self,
         max_new_tokens: int = 128,
@@ -33,9 +41,9 @@ class SamplingParams:
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
         spaces_between_special_tokens: bool = True,
-        regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        regex: Optional[str] = None,
         ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
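
The regex and ebnf parameters documented for SamplingParams are also accepted by the native /generate endpoint. A hedged request sketch; the URL, prompt, and pattern are placeholders for a locally running server:

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "max_new_tokens": 8,
        "temperature": 0.0,
        "regex": "(Paris|London)",  # constrain decoding to one of two strings
    },
}
print(requests.post("http://127.0.0.1:30000/generate", json=payload).json())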
sglang/srt/server.py
CHANGED
@@ -27,7 +27,9 @@ import signal
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
+
+import torch

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     assert_pkg_version,
@@ -124,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""

+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -543,7 +544,12 @@ def launch_server(

     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()

@@ -613,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -872,9 +878,11 @@ class Engine:
             tokenizer_manager.update_weights_from_distributed(obj, None)
         )

-    def update_weights_from_tensor(self,
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
         """Update weights from distributed source."""
-        obj = UpdateWeightsFromTensorReqInput(
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+        )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             tokenizer_manager.update_weights_from_tensor(obj, None)
@@ -910,10 +918,9 @@ class Runtime:
         atexit.register(self.shutdown)

         # Pre-allocate ports
-        for port in range(
+        for port in range(self.server_args.port, 40000):
             if is_port_available(port):
                 break
-            port += 1
         self.server_args.port = port

         self.url = self.server_args.url()
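
Engine.update_weights_from_tensor now takes a list of (name, tensor) pairs, serialized with MultiprocessingSerializer before being handed to the tokenizer manager. A hedged usage sketch; the model path, parameter name, and tensor shape are placeholders:

import torch
import sglang as sgl

# Placeholder model; any checkpoint the engine can load works the same way.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Push updated weights in-process, e.g. after an external optimizer/RLHF step.
named_tensors = [
    # Illustrative name and shape; a real update must match the parameter's actual shape.
    ("model.layers.0.self_attn.qkv_proj.weight", torch.zeros(1, 1)),
]
engine.update_weights_from_tensor(named_tensors)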
sglang/srt/server_args.py
CHANGED
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -42,7 +43,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = True
     dtype: str = "auto"
@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    skip_tokenizer_init: bool = False
     return_token_ids: bool = False

     # Port for the HTTP server
@@ -108,14 +109,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"

-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +118,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"

+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -140,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -240,6 +249,17 @@ class ServerArgs:
                 "Overlap scheduler is disabled."
             )

+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -276,17 +296,6 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -353,6 +362,7 @@ class ServerArgs:
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
             ],
             help="The quantization method.",
         )
@@ -394,6 +404,17 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )

         # Memory and scheduling
         parser.add_argument(
@@ -602,43 +623,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )

-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +662,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )

+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
@@ -751,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -869,7 +928,10 @@ class PortArgs:
         while True:
             if is_port_available(port):
                 break
-            port
+            if port < 60000:
+                port += 42
+            else:
+                port -= 43

         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,