sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +40 -18
  5. sglang/global_config.py +21 -16
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +33 -59
  15. sglang/srt/layers/token_attention.py +4 -8
  16. sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
  17. sglang/srt/managers/controller/infer_batch.py +244 -36
  18. sglang/srt/managers/controller/manager_single.py +1 -1
  19. sglang/srt/managers/controller/model_runner.py +69 -284
  20. sglang/srt/managers/controller/tp_worker.py +39 -20
  21. sglang/srt/managers/detokenizer_manager.py +4 -2
  22. sglang/srt/managers/io_struct.py +1 -1
  23. sglang/srt/managers/tokenizer_manager.py +14 -13
  24. sglang/srt/memory_pool.py +33 -6
  25. sglang/srt/model_config.py +6 -0
  26. sglang/srt/models/gemma2.py +436 -0
  27. sglang/srt/models/llama2.py +3 -3
  28. sglang/srt/models/llama_classification.py +10 -7
  29. sglang/srt/models/minicpm.py +373 -0
  30. sglang/srt/models/qwen2_moe.py +454 -0
  31. sglang/srt/openai_api_adapter.py +2 -2
  32. sglang/srt/openai_protocol.py +1 -1
  33. sglang/srt/server.py +18 -8
  34. sglang/srt/server_args.py +24 -20
  35. sglang/srt/utils.py +68 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
  37. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
  38. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
  39. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py ADDED
@@ -0,0 +1,454 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
+"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.config import CacheConfig
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors, SamplerOutput
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+class Qwen2MoeMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           reduce_results=reduce_results)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Qwen2MoeSparseMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.experts = FusedMoE(num_experts=config.num_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config)
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None)
+        if config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen2MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+            )
+        else:
+            self.shared_expert = None
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
+                                                  1,
+                                                  bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        shared_output = None
+        if self.shared_expert is not None:
+            shared_output = self.shared_expert(hidden_states)
+            if self.shared_expert_gate is not None:
+                shared_output = F.sigmoid(
+                    self.shared_expert_gate(hidden_states)) * shared_output
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Qwen2MoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(self.num_heads,
+                                   self.head_dim,
+                                   self.scaling,
+                                   num_kv_heads=self.num_kv_heads,
+                                   layer_id=layer_id)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen2MoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = Qwen2MoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
+                           config.mlp_only_layers)
+        if (layer_id not in mlp_only_layers) and (
+                config.num_experts > 0 and
+                (layer_id + 1) % config.decoder_sparse_step == 0):
+            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
+                                              quant_config=quant_config)
+        else:
+            self.mlp = Qwen2MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen2MoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList([
+            Qwen2MoeDecoderLayer(config,
+                                 layer_id,
+                                 cache_config,
+                                 quant_config=quant_config)
+            for layer_id in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(positions,
+                                            hidden_states,
+                                            input_metadata,
+                                            residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class Qwen2MoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2MoeModel(config, cache_config, quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata,
+                                   input_embeds)
+        return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
+                                     input_metadata)
+
+    def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor,
+                       input_metadata: InputMetadata) -> torch.Tensor:
+        logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
+                                       input_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+             else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts) for shard_id,
+            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+
+EntryClass = Qwen2MoeForCausalLM
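
Note: the new model registers itself through `EntryClass`, so it is served like any other sglang model. A minimal usage sketch follows; the checkpoint name and tp_size are illustrative assumptions, not part of this diff.

# Hypothetical usage sketch: serve the newly added Qwen2MoE model with sglang.
import sglang as sgl

runtime = sgl.Runtime(
    model_path="Qwen/Qwen2-57B-A14B-Instruct",  # any Qwen2MoeForCausalLM checkpoint (assumed)
    tp_size=4,                                  # large MoE checkpoints usually need tensor parallelism
)
sgl.set_default_backend(runtime)

@sgl.function
def ask(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = ask.run(question="What is a mixture-of-experts model?")
print(state["answer"])
runtime.shutdown()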
sglang/srt/openai_api_adapter.py CHANGED
@@ -164,7 +164,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     logprobs = None
 
                 delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
+                stream_buffer = stream_buffer + delta
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
                     text=delta,
@@ -323,7 +323,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 
                 text = content["text"]
                 delta = text[len(stream_buffer) :]
-                stream_buffer = text
+                stream_buffer = stream_buffer + delta
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=DeltaMessage(content=delta),
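
Note: both streaming hunks switch from resetting the buffer to the server's full text to appending only the emitted delta, so the buffer always matches exactly what has been sent to the client. A self-contained sketch of that accumulation pattern (names are illustrative, not sglang internals):

# Minimal sketch of the incremental streaming pattern adopted above.
# `chunks` stands in for the generation stream; each chunk carries the full text so far.
def stream_deltas(chunks):
    stream_buffer = ""
    for text in chunks:
        delta = text[len(stream_buffer):]       # only the newly generated suffix
        stream_buffer = stream_buffer + delta   # track exactly what was emitted
        if delta:
            yield delta

assert list(stream_deltas(["He", "Hello", "Hello!"])) == ["He", "llo", "!"]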
sglang/srt/openai_protocol.py CHANGED
@@ -134,7 +134,7 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = 16
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
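
Note: with `max_tokens` now defaulting to 16 instead of None, chat-completion clients that never set it will be cut off early. A hedged example of overriding it explicitly; the endpoint, port, and model name are placeholders, not taken from this diff.

# Sketch: pass max_tokens explicitly to avoid the new 16-token default.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumed local sglang server
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Summarize the diff."}],
        "max_tokens": 256,  # override the new default of 16
    },
)
print(resp.json()["choices"][0]["message"]["content"])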
sglang/srt/server.py CHANGED
@@ -51,13 +51,12 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    send_addrs_to_rank_0,
     receive_addrs,
+    send_addrs_to_rank_0,
     start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger(__name__)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -147,14 +146,19 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
-                           "reinstall the latest version by following the instructions "
-                           "at https://docs.flashinfer.ai/installation.html.")
+        assert_pkg_version(
+            "flashinfer",
+            "0.0.8",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
@@ -176,7 +180,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             ModelPortArgs(
                 nccl_port=ports[3 + i * (tp_size_local + 1)],
                 model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+                model_tp_ports=ports[
+                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+                ],
             )
         )
     port_args = PortArgs(
@@ -194,9 +200,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         else:
             receive_addrs(model_port_args[0], server_args)
         for i in range(tp_size_local):
-            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+            start_rpyc_service_process(
+                ModelTpService, model_port_args[0].model_tp_ports[i]
+            )
         if server_args.node_rank != 0:
-            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            logger.info(
+                f"[node_rank={server_args.node_rank}]: Listen for connections..."
+            )
             while True:
                 pass
 
sglang/srt/server_args.py CHANGED
@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: float = 0.8
 
     # Other runtime options
     tp_size: int = 1
@@ -53,8 +53,10 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
+    enable_p2p_check: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -66,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.90
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -137,17 +139,16 @@ class ServerArgs:
             "--dtype",
             type=str,
             default=ServerArgs.dtype,
-            choices=[
-                "auto", "half", "float16", "bfloat16", "float", "float32"
-            ],
-            help='Data type for model weights and activations.\n\n'
+            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+            help="Data type for model weights and activations.\n\n"
             '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            'BF16 precision for BF16 models.\n'
+            "BF16 precision for BF16 models.\n"
             '* "half" for FP16. Recommended for AWQ quantization.\n'
             '* "float16" is the same as "half".\n'
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.')
+            '* "float32" for FP32 precision.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -271,19 +272,12 @@ class ServerArgs:
         parser.add_argument(
             "--nccl-init-addr",
             type=str,
-            help="The nccl init address of multi-node server."
+            help="The nccl init address of multi-node server.",
         )
         parser.add_argument(
-            "--nnodes",
-            type=int,
-            default=1,
-            help="The number of nodes."
-        )
-        parser.add_argument(
-            "--node-rank",
-            type=int,
-            help="The node rank."
+            "--nnodes", type=int, default=1, help="The number of nodes."
         )
+        parser.add_argument("--node-rank", type=int, help="The node rank.")
 
         # Optimization/debug options
         parser.add_argument(
@@ -301,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
@@ -312,6 +311,11 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels",
         )
+        parser.add_argument(
+            "--enable-p2p-check",
+            action="store_true",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
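
Note: the two new CLI flags map directly onto the `disable_cuda_graph` and `enable_p2p_check` fields added above. A sketch of setting them programmatically; the field names come from this diff, while the model path is a placeholder.

# Sketch only: constructing ServerArgs with the options added in 0.1.20.
# Equivalent CLI flags: --disable-cuda-graph and --enable-p2p-check.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder checkpoint
    disable_cuda_graph=True,   # skip the new CUDA graph decode path
    enable_p2p_check=True,     # verify GPU peer-to-peer access instead of assuming it
)
print(args.mem_fraction_static)  # filled in by __post_init__ based on tp_size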