sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0

sglang/srt/models/llama.py CHANGED
@@ -100,6 +100,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +133,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +195,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -206,6 +212,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
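
The new bias plumbing lets llama.py serve checkpoints whose attention projections carry bias terms (the llamafied Qwen2.5 and internlm configs named in the comments). A minimal sketch of the fallback logic, using mock config objects rather than real Hugging Face configs:

    # Illustration only: SimpleNamespace stands in for a Hugging Face config object.
    from types import SimpleNamespace

    def resolve_attention_bias(config) -> bool:
        # Same fallback chain as the LlamaDecoderLayer change above.
        return getattr(config, "attention_bias", False) or getattr(config, "bias", False)

    assert resolve_attention_bias(SimpleNamespace(attention_bias=True))  # llamafied Qwen2.5 style
    assert resolve_attention_bias(SimpleNamespace(bias=True))            # internlm style
    assert not resolve_attention_bias(SimpleNamespace())                 # vanilla Llama: no bias
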

sglang/srt/openai_api/adapter.py CHANGED
@@ -696,14 +696,6 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
 
 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)
 
@@ -1176,15 +1168,6 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        # For example, if 'ebnf' is given:
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 

sglang/srt/openai_api/protocol.py CHANGED
@@ -171,15 +171,15 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    regex: Optional[str] = None
     json_schema: Optional[str] = None
+    regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -315,13 +315,13 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     min_tokens: int = 0
     regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None
 
 
 class FunctionResponse(BaseModel):
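
Taken together with the adapter change above, grammar constraints no longer need server-side unwrapping of extra_body: ebnf and regex are ordinary, validated fields on CompletionRequest and ChatCompletionRequest. A hedged sketch of a chat completion that carries an EBNF grammar in the request body; the endpoint, grammar, and model name are placeholders, not taken from the diff:

    import requests  # any HTTP client works; this is illustration only

    resp = requests.post(
        "http://127.0.0.1:30000/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "Answer yes or no."}],
            "ebnf": 'root ::= "yes" | "no"',  # now a first-class request field
            "max_tokens": 4,
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["message"]["content"])

Clients built on the OpenAI Python SDK can typically keep passing these via extra_body, since that SDK merges extra_body keys into the top-level request JSON.
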

sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -232,3 +232,25 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        # Apply logit_bias
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        # min-token, presence, frequency
+        if self.linear_penalties is not None:
+            logits.add_(self.linear_penalties)
+
+        # repetition
+        if self.scaling_penalties is not None:
+            logits[:] = torch.where(
+                logits > 0,
+                logits / self.scaling_penalties,
+                logits * self.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if self.vocab_mask is not None:
+            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
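
The ordering in apply_logits_bias (additive logit_bias, then linear penalties, then sign-aware repetition scaling, then the grammar vocab mask) is easy to verify on a toy tensor. A standalone sketch with made-up values that mirrors the same sequence using plain torch ops:

    import torch

    logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])             # one request, vocab size 4
    logit_bias = torch.tensor([[0.0, 0.0, 0.0, -100.0]])       # e.g. strongly discourage token 3
    scaling_penalties = torch.tensor([[1.2, 1.2, 1.2, 1.2]])   # repetition penalty > 1

    logits.add_(logit_bias)
    # Repetition penalty divides positive logits and multiplies negative ones,
    # so every penalized token moves toward "less likely" regardless of sign.
    logits[:] = torch.where(
        logits > 0, logits / scaling_penalties, logits * scaling_penalties
    )
    # Grammar mask: tokens the grammar forbids are dropped to -inf before sampling.
    vocab_mask = torch.tensor([[False, False, True, False]])   # True = blocked
    logits.masked_fill_(vocab_mask, float("-inf"))
    print(logits)
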

sglang/srt/sampling/sampling_params.py CHANGED
@@ -19,6 +19,14 @@ _SAMPLING_EPS = 1e-6
 
 
 class SamplingParams:
+    """
+    The sampling parameters.
+
+    See docs/references/sampling_params.md or
+    https://sgl-project.github.io/references/sampling_params.html
+    for the documentation.
+    """
+
     def __init__(
         self,
         max_new_tokens: int = 128,
@@ -33,9 +41,9 @@ class SamplingParams:
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
         spaces_between_special_tokens: bool = True,
-        regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        regex: Optional[str] = None,
         ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
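
Since regex and ebnf remain keyword arguments, callers that pass them by name are unaffected by the reordering. A small construction sketch; parameters not visible in the hunks above (such as temperature) are assumed to exist with their usual meaning:

    from sglang.srt.sampling.sampling_params import SamplingParams

    params = SamplingParams(
        max_new_tokens=64,
        temperature=0.0,          # assumed parameter, not shown in this diff
        regex=r"(Paris|Lyon)",    # constrain decoding with a regular expression
        no_stop_trim=False,
    )
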
sglang/srt/server.py CHANGED
@@ -27,7 +27,9 @@ import signal
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
+
+import torch
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     assert_pkg_version,
@@ -124,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
 
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -543,7 +544,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()
 
@@ -613,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -872,9 +878,11 @@ class Engine:
             tokenizer_manager.update_weights_from_distributed(obj, None)
         )
 
-    def update_weights_from_tensor(self, name, tensor):
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
         """Update weights from distributed source."""
-        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+        )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             tokenizer_manager.update_weights_from_tensor(obj, None)
@@ -910,10 +918,9 @@ class Runtime:
         atexit.register(self.shutdown)
 
         # Pre-allocate ports
-        for port in range(10000, 40000):
+        for port in range(self.server_args.port, 40000):
             if is_port_available(port):
                 break
-            port += 1
         self.server_args.port = port
 
         self.url = self.server_args.url()
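
The Engine.update_weights_from_tensor change is the one API break in this file: callers now pass a list of (name, tensor) pairs, serialized once via MultiprocessingSerializer instead of shipping a single named tensor per call. A hedged usage sketch; the model path and parameter name are placeholders, and constructing an Engine loads the model:

    import torch
    from sglang.srt.server import Engine

    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

    # Old shape: engine.update_weights_from_tensor(name, tensor)   (one tensor per call)
    # New shape: one batched call with (name, tensor) pairs.
    named_tensors = [
        ("model.embed_tokens.weight", torch.zeros(128256, 4096)),  # placeholder name/shape
    ]
    engine.update_weights_from_tensor(named_tensors)
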
sglang/srt/server_args.py CHANGED
@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -42,7 +43,6 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
-    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = True
     dtype: str = "auto"
@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    skip_tokenizer_init: bool = False
     return_token_ids: bool = False
 
     # Port for the HTTP server
@@ -108,14 +109,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
@@ -125,6 +118,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
@@ -140,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -240,6 +249,17 @@ class ServerArgs:
                 "Overlap scheduler is disabled."
             )
 
+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -276,17 +296,6 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
            "always use the slow tokenizer.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -353,6 +362,7 @@ class ServerArgs:
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
             ],
             help="The quantization method.",
         )
@@ -394,6 +404,17 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -602,43 +623,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",
@@ -678,6 +662,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
@@ -751,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -869,7 +928,10 @@ class PortArgs:
         while True:
             if is_port_available(port):
                 break
-            port += 42
+            if port < 60000:
+                port += 42
+            else:
+                port -= 43
 
         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
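
For the new speculative-decoding options, a hedged ServerArgs sketch shows how the fields fit together and what __post_init__ enforces for EAGLE (radix cache, chunked prefill, and the overlap scheduler are turned off). The model and draft paths are placeholders, and building ServerArgs may inspect the local GPU environment:

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",          # placeholder target model
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="path/to/eagle-draft",  # placeholder draft model
        speculative_num_steps=5,
        speculative_eagle_topk=8,
        speculative_num_draft_tokens=64,
        cuda_graph_bs=[1, 2, 4, 8],                          # explicit capture sizes (new flag)
    )

    # Constraints applied by __post_init__ for EAGLE, per the hunk above:
    assert args.disable_radix_cache and args.disable_overlap_schedule
    assert args.chunked_prefill_size == -1
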