sglang 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +34 -16
  5. sglang/global_config.py +1 -0
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +24 -27
  15. sglang/srt/layers/token_attention.py +4 -1
  16. sglang/srt/managers/controller/infer_batch.py +2 -2
  17. sglang/srt/managers/controller/manager_single.py +1 -1
  18. sglang/srt/managers/controller/model_runner.py +27 -15
  19. sglang/srt/managers/controller/tp_worker.py +31 -14
  20. sglang/srt/managers/detokenizer_manager.py +4 -2
  21. sglang/srt/managers/io_struct.py +1 -1
  22. sglang/srt/managers/tokenizer_manager.py +14 -13
  23. sglang/srt/model_config.py +6 -0
  24. sglang/srt/models/gemma2.py +436 -0
  25. sglang/srt/models/llama2.py +3 -3
  26. sglang/srt/models/llama_classification.py +10 -7
  27. sglang/srt/models/minicpm.py +373 -0
  28. sglang/srt/models/qwen2_moe.py +454 -0
  29. sglang/srt/openai_api_adapter.py +2 -2
  30. sglang/srt/openai_protocol.py +1 -1
  31. sglang/srt/server.py +17 -8
  32. sglang/srt/server_args.py +14 -16
  33. sglang/srt/utils.py +68 -35
  34. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
  35. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  37. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
  38. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py ADDED
@@ -0,0 +1,454 @@
+ # coding=utf-8
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
+ """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from vllm.config import CacheConfig
+ from vllm.distributed import (get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
+ from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.sequence import IntermediateTensors, SamplerOutput
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+ class Qwen2MoeMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config)
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results)
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+ class Qwen2MoeSparseMoeBlock(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+
+ if self.tp_size > config.num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {config.num_experts}.")
+
+ self.experts = FusedMoE(num_experts=config.num_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config)
+
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.num_experts,
+ bias=False,
+ quant_config=None)
+ if config.shared_expert_intermediate_size > 0:
+ self.shared_expert = Qwen2MoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.shared_expert_intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ )
+ else:
+ self.shared_expert = None
+ self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
+ 1,
+ bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ shared_output = None
+ if self.shared_expert is not None:
+ shared_output = self.shared_expert(hidden_states)
+ if self.shared_expert_gate is not None:
+ shared_output = F.sigmoid(
+ self.shared_expert_gate(hidden_states)) * shared_output
+
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states)
+ final_hidden_states = self.experts(hidden_states=hidden_states,
+ router_logits=router_logits)
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states)
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ class Qwen2MoeAttention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ layer_id: int = 0,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[Dict[str, Any]] = None,
+ max_position_embeddings: int = 8192,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = hidden_size // self.total_num_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=True,
+ quant_config=quant_config,
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ )
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ )
+ self.attn = RadixAttention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ layer_id=layer_id)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ input_metadata: InputMetadata
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v, input_metadata)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+ class Qwen2MoeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ layer_id: int,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 8192)
+ self.self_attn = Qwen2MoeAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ layer_id=layer_id,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ )
+
+ # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+ # `mlp_only_layers` in the config.
+ mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
+ config.mlp_only_layers)
+ if (layer_id not in mlp_only_layers) and (
+ config.num_experts > 0 and
+ (layer_id + 1) % config.decoder_sparse_step == 0):
+ self.mlp = Qwen2MoeSparseMoeBlock(config=config,
+ quant_config=quant_config)
+ else:
+ self.mlp = Qwen2MoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ input_metadata: InputMetadata,
+ residual: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ input_metadata=input_metadata
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+ class Qwen2MoeModel(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.layers = nn.ModuleList([
+ Qwen2MoeDecoderLayer(config,
+ layer_id,
+ cache_config,
+ quant_config=quant_config)
+ for layer_id in range(config.num_hidden_layers)
+ ])
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ input_metadata: InputMetadata,
+ input_embeds: torch.Tensor = None
+ ) -> torch.Tensor:
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+ residual = None
+ for i in range(len(self.layers)):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions,
+ hidden_states,
+ input_metadata,
+ residual)
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+ class Qwen2MoeForCausalLM(nn.Module):
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Qwen2MoeModel(config, cache_config, quant_config)
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ input_metadata: InputMetadata,
+ input_embeds: torch.Tensor = None
+ ) -> torch.Tensor:
+ hidden_states = self.model(input_ids, positions, input_metadata,
+ input_embeds)
+ return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
+ input_metadata)
+
+ def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor,
+ input_metadata: InputMetadata) -> torch.Tensor:
+ logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
+ input_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: Optional[torch.Tensor],
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ expert_params_mapping = [
+ # These are the weights for the experts
+ # (param_name, weight_name, expert_id, shard_id)
+ ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+ else "experts.w2_weight",
+ f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+ for expert_id in range(self.config.num_experts) for shard_id,
+ weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+ ]
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if "mlp.experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ weight_name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+ EntryClass = Qwen2MoeForCausalLM
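
The new file registers EntryClass = Qwen2MoeForCausalLM, which is how sglang's model loader picks up the implementation by architecture name. As a rough usage sketch (not part of this diff; the checkpoint name and tp_size value are illustrative), a Qwen2 MoE model could be served in-process with the existing Runtime API:

    import sglang as sgl

    # Hypothetical sketch: serve the newly supported Qwen2 MoE architecture.
    # Keyword arguments are forwarded to ServerArgs (see server_args.py below).
    runtime = sgl.Runtime(model_path="Qwen/Qwen2-57B-A14B-Instruct", tp_size=4)
    sgl.set_default_backend(runtime)
    # ... run sgl programs against the runtime ...
    runtime.shutdown()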
sglang/srt/openai_api_adapter.py CHANGED
@@ -164,7 +164,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  logprobs = None

  delta = text[len(stream_buffer) :]
- stream_buffer = content["text"]
+ stream_buffer = stream_buffer + delta
  choice_data = CompletionResponseStreamChoice(
  index=0,
  text=delta,
@@ -323,7 +323,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

  text = content["text"]
  delta = text[len(stream_buffer) :]
- stream_buffer = text
+ stream_buffer = stream_buffer + delta
  choice_data = ChatCompletionResponseStreamChoice(
  index=0,
  delta=DeltaMessage(content=delta),
sglang/srt/openai_protocol.py CHANGED
@@ -134,7 +134,7 @@ class ChatCompletionRequest(BaseModel):
  logit_bias: Optional[Dict[str, float]] = None
  logprobs: Optional[bool] = False
  top_logprobs: Optional[int] = None
- max_tokens: Optional[int] = None
+ max_tokens: Optional[int] = 16
  n: Optional[int] = 1
  presence_penalty: Optional[float] = 0.0
  response_format: Optional[ResponseFormat] = None
sglang/srt/server.py CHANGED
@@ -51,13 +51,12 @@ from sglang.srt.utils import (
  allocate_init_ports,
  assert_pkg_version,
  enable_show_time_cost,
- send_addrs_to_rank_0,
  receive_addrs,
+ send_addrs_to_rank_0,
  start_rpyc_service_process,
  )
  from sglang.utils import get_exception_traceback

-
  logger = logging.getLogger(__name__)

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  if server_args.disable_disk_cache:
  disable_cache()
  if not server_args.disable_flashinfer:
- assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
- "reinstall the latest version by following the instructions "
- "at https://docs.flashinfer.ai/installation.html.")
+ assert_pkg_version(
+ "flashinfer",
+ "0.0.8",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
  if server_args.chat_template:
  # TODO: replace this with huggingface transformers template
  load_chat_template_for_openai_api(server_args.chat_template)
@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  ModelPortArgs(
  nccl_port=ports[3 + i * (tp_size_local + 1)],
  model_tp_ips=[None] * tp_size_local,
- model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+ model_tp_ports=ports[
+ 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+ ],
  )
  )
  port_args = PortArgs(
@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
  else:
  receive_addrs(model_port_args[0], server_args)
  for i in range(tp_size_local):
- start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+ start_rpyc_service_process(
+ ModelTpService, model_port_args[0].model_tp_ports[i]
+ )
  if server_args.node_rank != 0:
- logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+ logger.info(
+ f"[node_rank={server_args.node_rank}]: Listen for connections..."
+ )
  while True:
  pass

sglang/srt/server_args.py CHANGED
@@ -55,6 +55,7 @@ class ServerArgs:
  disable_regex_jump_forward: bool = False
  disable_disk_cache: bool = False
  attention_reduce_in_fp32: bool = False
+ enable_p2p_check: bool = False

  # Distributed args
  nccl_init_addr: Optional[str] = None
@@ -137,17 +138,16 @@
  "--dtype",
  type=str,
  default=ServerArgs.dtype,
- choices=[
- "auto", "half", "float16", "bfloat16", "float", "float32"
- ],
- help='Data type for model weights and activations.\n\n'
+ choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+ help="Data type for model weights and activations.\n\n"
  '* "auto" will use FP16 precision for FP32 and FP16 models, and '
- 'BF16 precision for BF16 models.\n'
+ "BF16 precision for BF16 models.\n"
  '* "half" for FP16. Recommended for AWQ quantization.\n'
  '* "float16" is the same as "half".\n'
  '* "bfloat16" for a balance between precision and range.\n'
  '* "float" is shorthand for FP32 precision.\n'
- '* "float32" for FP32 precision.')
+ '* "float32" for FP32 precision.',
+ )
  parser.add_argument(
  "--trust-remote-code",
  action="store_true",
@@ -271,19 +271,12 @@
  parser.add_argument(
  "--nccl-init-addr",
  type=str,
- help="The nccl init address of multi-node server."
- )
- parser.add_argument(
- "--nnodes",
- type=int,
- default=1,
- help="The number of nodes."
+ help="The nccl init address of multi-node server.",
  )
  parser.add_argument(
- "--node-rank",
- type=int,
- help="The node rank."
+ "--nnodes", type=int, default=1, help="The number of nodes."
  )
+ parser.add_argument("--node-rank", type=int, help="The node rank.")

  # Optimization/debug options
  parser.add_argument(
@@ -312,6 +305,11 @@
  help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
  "This only affects Triton attention kernels",
  )
+ parser.add_argument(
+ "--enable-p2p-check",
+ action="store_true",
+ help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+ )

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
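
The server_args.py hunks above add an enable_p2p_check option (default False) together with the matching --enable-p2p-check CLI flag, so GPU peer-to-peer access is checked only when explicitly requested. A minimal sketch of setting it programmatically (the model path here is only a placeholder, not taken from this diff):

    from sglang.srt.server_args import ServerArgs

    # Hypothetical sketch: construct ServerArgs with the new flag enabled;
    # equivalent to passing --enable-p2p-check on the launch command line.
    args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        enable_p2p_check=True,
    )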