sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +758 -0
  7. sglang/check_env.py +171 -0
  8. sglang/lang/backend/__init__.py +0 -0
  9. sglang/lang/backend/anthropic.py +77 -0
  10. sglang/lang/backend/base_backend.py +80 -0
  11. sglang/lang/backend/litellm.py +90 -0
  12. sglang/lang/backend/openai.py +438 -0
  13. sglang/lang/backend/runtime_endpoint.py +283 -0
  14. sglang/lang/backend/vertexai.py +149 -0
  15. sglang/lang/tracer.py +1 -1
  16. sglang/launch_server.py +1 -1
  17. sglang/launch_server_llavavid.py +1 -4
  18. sglang/srt/conversation.py +1 -1
  19. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  20. sglang/srt/layers/extend_attention.py +0 -39
  21. sglang/srt/layers/linear.py +869 -0
  22. sglang/srt/layers/quantization/__init__.py +49 -0
  23. sglang/srt/layers/quantization/fp8.py +662 -0
  24. sglang/srt/layers/radix_attention.py +31 -5
  25. sglang/srt/layers/token_attention.py +1 -51
  26. sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
  27. sglang/srt/managers/controller/infer_batch.py +47 -49
  28. sglang/srt/managers/controller/manager_multi.py +107 -100
  29. sglang/srt/managers/controller/manager_single.py +76 -96
  30. sglang/srt/managers/controller/model_runner.py +35 -23
  31. sglang/srt/managers/controller/tp_worker.py +127 -138
  32. sglang/srt/managers/detokenizer_manager.py +49 -5
  33. sglang/srt/managers/io_struct.py +36 -17
  34. sglang/srt/managers/tokenizer_manager.py +228 -125
  35. sglang/srt/memory_pool.py +19 -6
  36. sglang/srt/model_loader/model_loader.py +277 -0
  37. sglang/srt/model_loader/utils.py +260 -0
  38. sglang/srt/models/chatglm.py +1 -0
  39. sglang/srt/models/dbrx.py +1 -0
  40. sglang/srt/models/grok.py +1 -0
  41. sglang/srt/models/internlm2.py +317 -0
  42. sglang/srt/models/llama2.py +65 -16
  43. sglang/srt/models/llama_classification.py +1 -0
  44. sglang/srt/models/llava.py +1 -0
  45. sglang/srt/models/llavavid.py +1 -0
  46. sglang/srt/models/minicpm.py +1 -0
  47. sglang/srt/models/mixtral.py +1 -0
  48. sglang/srt/models/mixtral_quant.py +1 -0
  49. sglang/srt/models/qwen.py +1 -0
  50. sglang/srt/models/qwen2.py +6 -0
  51. sglang/srt/models/qwen2_moe.py +7 -4
  52. sglang/srt/models/stablelm.py +1 -0
  53. sglang/srt/openai_api/adapter.py +432 -0
  54. sglang/srt/openai_api/api_adapter.py +432 -0
  55. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  56. sglang/srt/openai_api/openai_protocol.py +207 -0
  57. sglang/srt/openai_api/protocol.py +208 -0
  58. sglang/srt/openai_protocol.py +17 -0
  59. sglang/srt/sampling_params.py +2 -0
  60. sglang/srt/server.py +113 -84
  61. sglang/srt/server_args.py +23 -15
  62. sglang/srt/utils.py +16 -117
  63. sglang/test/test_conversation.py +1 -1
  64. sglang/test/test_openai_protocol.py +1 -1
  65. sglang/test/test_programs.py +1 -1
  66. sglang/test/test_utils.py +2 -2
  67. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
  68. sglang-0.1.22.dist-info/RECORD +103 -0
  69. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  70. sglang-0.1.21.dist-info/RECORD +0 -82
  71. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  72. {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py ADDED
@@ -0,0 +1,208 @@
+ """Pydantic models for OpenAI API protocol"""
+
+ import time
+ from typing import Dict, List, Optional, Union
+
+ from pydantic import BaseModel, Field
+ from typing_extensions import Literal
+
+
+ class ModelCard(BaseModel):
+     """Model cards."""
+
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "sglang"
+     root: Optional[str] = None
+
+
+ class ModelList(BaseModel):
+     """Model list consists of model cards."""
+
+     object: str = "list"
+     data: List[ModelCard] = []
+
+
+ class ErrorResponse(BaseModel):
+     object: str = "error"
+     message: str
+     type: str
+     param: Optional[str] = None
+     code: int
+
+
+ class LogProbs(BaseModel):
+     text_offset: List[int] = Field(default_factory=list)
+     token_logprobs: List[Optional[float]] = Field(default_factory=list)
+     tokens: List[str] = Field(default_factory=list)
+     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+ class UsageInfo(BaseModel):
+     prompt_tokens: int = 0
+     total_tokens: int = 0
+     completion_tokens: Optional[int] = 0
+
+
+ class CompletionRequest(BaseModel):
+     # Ordered by official OpenAI API documentation
+     # https://platform.openai.com/docs/api-reference/completions/create
+     model: str
+     prompt: Union[List[int], List[List[int]], str, List[str]]
+     best_of: Optional[int] = None
+     echo: Optional[bool] = False
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     logprobs: Optional[int] = None
+     max_tokens: Optional[int] = 16
+     n: int = 1
+     presence_penalty: Optional[float] = 0.0
+     seed: Optional[int] = None
+     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+     stream: Optional[bool] = False
+     suffix: Optional[str] = None
+     temperature: Optional[float] = 1.0
+     top_p: Optional[float] = 1.0
+     user: Optional[str] = None
+
+     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+     regex: Optional[str] = None
+     ignore_eos: Optional[bool] = False
+
+
+ class CompletionResponseChoice(BaseModel):
+     index: int
+     text: str
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+
+
+ class CompletionResponse(BaseModel):
+     id: str
+     object: str = "text_completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[CompletionResponseChoice]
+     usage: UsageInfo
+
+
+ class CompletionResponseStreamChoice(BaseModel):
+     index: int
+     text: str
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+
+
+ class CompletionStreamResponse(BaseModel):
+     id: str
+     object: str = "text_completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[CompletionResponseStreamChoice]
+     usage: UsageInfo
+
+
+ class ChatCompletionMessageGenericParam(BaseModel):
+     role: Literal["system", "assistant"]
+     content: str
+
+
+ class ChatCompletionMessageContentTextPart(BaseModel):
+     type: Literal["text"]
+     text: str
+
+
+ class ChatCompletionMessageContentImageURL(BaseModel):
+     url: str
+     detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+ class ChatCompletionMessageContentImagePart(BaseModel):
+     type: Literal["image_url"]
+     image_url: ChatCompletionMessageContentImageURL
+
+
+ ChatCompletionMessageContentPart = Union[
+     ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+ ]
+
+
+ class ChatCompletionMessageUserParam(BaseModel):
+     role: Literal["user"]
+     content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ ChatCompletionMessageParam = Union[
+     ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+ ]
+
+
+ class ResponseFormat(BaseModel):
+     # type must be "json_object" or "text"
+     type: Literal["text", "json_object"]
+
+
+ class ChatCompletionRequest(BaseModel):
+     # Ordered by official OpenAI API documentation
+     # https://platform.openai.com/docs/api-reference/chat/create
+     messages: List[ChatCompletionMessageParam]
+     model: str
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     logprobs: Optional[bool] = False
+     top_logprobs: Optional[int] = None
+     max_tokens: Optional[int] = 16
+     n: Optional[int] = 1
+     presence_penalty: Optional[float] = 0.0
+     response_format: Optional[ResponseFormat] = None
+     seed: Optional[int] = None
+     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+     stream: Optional[bool] = False
+     temperature: Optional[float] = 0.7
+     top_p: Optional[float] = 1.0
+     user: Optional[str] = None
+
+     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+     regex: Optional[str] = None
+
+
+ class ChatMessage(BaseModel):
+     role: Optional[str] = None
+     content: Optional[str] = None
+
+
+ class ChatCompletionResponseChoice(BaseModel):
+     index: int
+     message: ChatMessage
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+
+
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionResponseChoice]
+     usage: UsageInfo
+
+
+ class DeltaMessage(BaseModel):
+     role: Optional[str] = None
+     content: Optional[str] = None
+
+
+ class ChatCompletionResponseStreamChoice(BaseModel):
+     index: int
+     delta: DeltaMessage
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+
+
+ class ChatCompletionStreamResponse(BaseModel):
+     id: str
+     object: str = "chat.completion.chunk"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionResponseStreamChoice]
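The new module mirrors the OpenAI request/response schema with Pydantic models. As a minimal sketch of how a chat-completions payload can be validated against them (the model name and messages below are placeholders, not taken from the package):

from sglang.srt.openai_api.protocol import ChatCompletionRequest

# Illustrative payload; field names follow the models added in this release.
payload = {
    "model": "placeholder/chat-model",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the capital of France."},
    ],
    "temperature": 0.7,
    "max_tokens": 32,
}

req = ChatCompletionRequest(**payload)  # raises a pydantic ValidationError on bad input
print(req.model, req.temperature, len(req.messages))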
sglang/srt/openai_protocol.py CHANGED
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
  from typing_extensions import Literal
 
 
+ class ModelCard(BaseModel):
+     """Model cards."""
+
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "sglang"
+     root: Optional[str] = None
+
+
+ class ModelList(BaseModel):
+     """Model list consists of model cards."""
+
+     object: str = "list"
+     data: List[ModelCard] = []
+
+
  class ErrorResponse(BaseModel):
      object: str = "error"
      message: str
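The same ModelCard/ModelList pair is also added to the legacy sglang/srt/openai_protocol.py module. A short sketch of the response body these models produce; the model id is a placeholder:

from sglang.srt.openai_protocol import ModelCard, ModelList

cards = [ModelCard(id="placeholder-model", root="placeholder-model")]
# .dict() is the pydantic-v1 spelling; on pydantic v2 use .model_dump() instead.
print(ModelList(data=cards).dict())
# -> {'object': 'list', 'data': [{'id': 'placeholder-model', 'object': 'model', ..., 'owned_by': 'sglang', ...}]}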
sglang/srt/sampling_params.py CHANGED
@@ -20,6 +20,7 @@ class SamplingParams:
          spaces_between_special_tokens: bool = True,
          dtype: Optional[str] = None,
          regex: Optional[str] = None,
+         n: int = 1,
      ) -> None:
          self.temperature = temperature
          self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
          self.spaces_between_special_tokens = spaces_between_special_tokens
          self.dtype = dtype
          self.regex = regex
+         self.n = n
 
          # Process some special cases
          if self.temperature < _SAMPLING_EPS:
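With the new field, SamplingParams can carry the number of generations per prompt, matching the OpenAI `n` parameter. A minimal sketch, assuming the other keyword arguments keep their existing defaults:

from sglang.srt.sampling_params import SamplingParams

# Ask for two samples per prompt; temperature is one of the existing kwargs shown above.
params = SamplingParams(temperature=0.8, n=2)
print(params.n)  # 2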
sglang/srt/server.py CHANGED
@@ -26,33 +26,33 @@ import uvloop
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse, Response, StreamingResponse
 
- from sglang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.constrained import disable_cache
  from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.controller.manager_multi import (
      start_controller_process as start_controller_process_multi,
  )
+ from sglang.srt.managers.controller.manager_single import launch_tp_servers
  from sglang.srt.managers.controller.manager_single import (
-     launch_tp_servers,
      start_controller_process as start_controller_process_single,
  )
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.openai_api_adapter import (
+ from sglang.srt.openai_api.adapter import (
      load_chat_template_for_openai_api,
      v1_chat_completions,
      v1_completions,
  )
- from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+ from sglang.srt.openai_api.protocol import ModelCard, ModelList
+ from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
      API_KEY_HEADER_NAME,
      APIKeyValidatorMiddleware,
      allocate_init_ports,
      assert_pkg_version,
      enable_show_time_cost,
-     receive_addrs,
-     send_addrs_to_rank_0,
+     set_ulimit,
  )
  from sglang.utils import get_exception_traceback
 
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
  app = FastAPI()
  tokenizer_manager = None
 
+ # Put some args for easily access
+ global_server_args_dict = {}
+
 
  @app.get("/health")
  async def health() -> Response:
@@ -95,6 +98,7 @@ async def flush_cache():
 
 
  async def generate_request(obj: GenerateReqInput, request: Request):
+     """Handle a generate request."""
      if obj.stream:
 
          async def stream_results():
@@ -135,7 +139,30 @@ async def openai_v1_chat_completions(raw_request: Request):
      return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
- def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+ @app.get("/v1/models")
+ def available_models():
+     """Show available models."""
+     model_names = [tokenizer_manager.model_path]
+     model_cards = []
+     for model_name in model_names:
+         model_cards.append(ModelCard(id=model_name, root=model_name))
+     return ModelList(data=model_cards)
+
+
+ def _set_global_server_args(server_args: ServerArgs):
+     global global_server_args_dict
+     global_server_args_dict = {
+         "disable_flashinfer": server_args.disable_flashinfer,
+         "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+     }
+
+
+ def launch_server(
+     server_args: ServerArgs,
+     model_overide_args: Optional[dict] = None,
+     pipe_finish_writer: Optional[mp.connection.Connection] = None,
+ ):
+     """Launch an HTTP server."""
      global tokenizer_manager
 
      logging.basicConfig(
@@ -146,6 +173,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      # Set global environments
      os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
      os.environ["NCCL_CUMEM_ENABLE"] = "0"
+     os.environ["NCCL_NVLS_ENABLE"] = "0"
+     set_ulimit()
      if server_args.show_time_cost:
          enable_show_time_cost()
      if server_args.disable_disk_cache:
@@ -153,7 +182,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      if not server_args.disable_flashinfer:
          assert_pkg_version(
              "flashinfer",
-             "0.0.8",
+             "0.1.0",
              "Please uninstall the old version and "
              "reinstall the latest version by following the instructions "
              "at https://docs.flashinfer.ai/installation.html.",
@@ -161,64 +190,61 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      if server_args.chat_template:
          # TODO: replace this with huggingface transformers template
          load_chat_template_for_openai_api(server_args.chat_template)
+     _set_global_server_args(server_args)
 
      # Allocate ports
-     assert server_args.tp_size % server_args.nnodes == 0
-     tp_size_local = server_args.tp_size // server_args.nnodes
      server_args.port, server_args.additional_ports = allocate_init_ports(
          server_args.port,
          server_args.additional_ports,
-         tp_size_local,
          server_args.dp_size,
      )
-
      ports = server_args.additional_ports
-     model_port_args = []
-     for i in range(server_args.dp_size):
-         model_port_args.append(
-             ModelPortArgs(
-                 nccl_port=ports[3 + i * (tp_size_local + 1)],
-                 model_tp_ips=[None] * tp_size_local,
-                 model_tp_ports=ports[
-                     3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                 ],
-             )
-         )
      port_args = PortArgs(
          tokenizer_port=ports[0],
-         router_port=ports[1],
+         controller_port=ports[1],
          detokenizer_port=ports[2],
-         model_port_args=model_port_args,
+         nccl_ports=ports[3:],
      )
 
-     # Handle multi-node tp
+     # Handle multi-node tensor parallelism
      if server_args.nnodes > 1:
          assert server_args.dp_size == 1, "Multi-node dp is not supported."
 
          if server_args.node_rank != 0:
              tp_size_local = server_args.tp_size // server_args.nnodes
-             gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-             tp_rank_range = list(range(server_args.node_rank * tp_size_local,
-                                        (server_args.node_rank + 1) * tp_size_local))
-             procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
-                                       port_args.model_port_args[0], model_overide_args)
+             gpu_ids = [
+                 i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+             ]
+             tp_rank_range = list(
+                 range(
+                     server_args.node_rank * tp_size_local,
+                     (server_args.node_rank + 1) * tp_size_local,
+                 )
+             )
+             procs = launch_tp_servers(
+                 gpu_ids,
+                 tp_rank_range,
+                 server_args,
+                 ports[3],
+                 model_overide_args,
+             )
              while True:
                  pass
 
      # Launch processes
      tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
      pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
      if server_args.dp_size == 1:
          start_process = start_controller_process_single
      else:
          start_process = start_controller_process_multi
-     proc_router = mp.Process(
+     proc_controller = mp.Process(
          target=start_process,
-         args=(server_args, port_args, pipe_router_writer, model_overide_args),
+         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
      )
-     proc_router.start()
+     proc_controller.start()
      proc_detoken = mp.Process(
          target=start_detokenizer_process,
          args=(
@@ -230,68 +256,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      proc_detoken.start()
 
      # Wait for the model to finish loading
-     router_init_state = pipe_router_reader.recv()
+     controller_init_state = pipe_controller_reader.recv()
      detoken_init_state = pipe_detoken_reader.recv()
 
-     if router_init_state != "init ok" or detoken_init_state != "init ok":
-         proc_router.kill()
+     if controller_init_state != "init ok" or detoken_init_state != "init ok":
+         proc_controller.kill()
          proc_detoken.kill()
          print(
-             f"Initialization failed. router_init_state: {router_init_state}", flush=True
+             f"Initialization failed. controller_init_state: {controller_init_state}",
+             flush=True,
          )
          print(
              f"Initialization failed. detoken_init_state: {detoken_init_state}",
              flush=True,
          )
          sys.exit(1)
-     assert proc_router.is_alive() and proc_detoken.is_alive()
+     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
      if server_args.api_key and server_args.api_key != "":
          app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
      # Send a warmup request
-     def _wait_and_warmup():
-         headers = {}
-         url = server_args.url()
-         if server_args.api_key:
-             headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-         # Wait until the server is launched
-         for _ in range(120):
-             time.sleep(0.5)
-             try:
-                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                 break
-             except requests.exceptions.RequestException:
-                 pass
-
-         # Send a warmup request
-         try:
-             for _ in range(server_args.dp_size):
-                 res = requests.post(
-                     url + "/generate",
-                     json={
-                         "text": "The capital city of France is",
-                         "sampling_params": {
-                             "temperature": 0,
-                             "max_new_tokens": 8,
-                         },
-                     },
-                     headers=headers,
-                     timeout=600,
-                 )
-                 assert res.status_code == 200
-         except Exception as e:
-             if pipe_finish_writer is not None:
-                 pipe_finish_writer.send(get_exception_traceback())
-             print(f"Initialization failed. warmup error: {e}", flush=True)
-             raise e
-
-         logger.info("The server is fired up and ready to roll!")
-         if pipe_finish_writer is not None:
-             pipe_finish_writer.send("init ok")
-
-     t = threading.Thread(target=_wait_and_warmup)
+     t = threading.Thread(
+         target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+     )
      t.start()
 
      # Listen for requests
@@ -308,6 +296,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
          t.join()
 
 
+ def _wait_and_warmup(server_args, pipe_finish_writer):
+     headers = {}
+     url = server_args.url()
+     if server_args.api_key:
+         headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+     # Wait until the server is launched
+     for _ in range(120):
+         time.sleep(0.5)
+         try:
+             requests.get(url + "/get_model_info", timeout=5, headers=headers)
+             break
+         except requests.exceptions.RequestException:
+             pass
+
+     # Send a warmup request
+     try:
+         for _ in range(server_args.dp_size):
+             res = requests.post(
+                 url + "/generate",
+                 json={
+                     "text": "The capital city of France is",
+                     "sampling_params": {
+                         "temperature": 0,
+                         "max_new_tokens": 8,
+                     },
+                 },
+                 headers=headers,
+                 timeout=600,
+             )
+             assert res.status_code == 200
+     except Exception as e:
+         if pipe_finish_writer is not None:
+             pipe_finish_writer.send(get_exception_traceback())
+         print(f"Initialization failed. warmup error: {e}", flush=True)
+         raise e
+
+     logger.info("The server is fired up and ready to roll!")
+     if pipe_finish_writer is not None:
+         pipe_finish_writer.send("init ok")
+
+
  class Runtime:
      """
      A wrapper for the server.
@@ -329,7 +359,6 @@ class Runtime:
          self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
              self.server_args.port,
              self.server_args.additional_ports,
-             self.server_args.tp_size,
              self.server_args.dp_size,
          )
 
@@ -342,7 +371,7 @@ class Runtime:
          pipe_reader, pipe_writer = mp.Pipe(duplex=False)
          proc = mp.Process(
              target=launch_server,
-             args=(self.server_args, pipe_writer, model_overide_args),
+             args=(self.server_args, model_overide_args, pipe_writer),
          )
          proc.start()
          pipe_writer.close()
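Besides the router-to-controller renaming, server.py now serves GET /v1/models from the new ModelCard/ModelList models, and launch_server takes model_overide_args before pipe_finish_writer. A client-side sketch against a locally running server; the host, port, and launch command are assumptions, not part of this diff:

import requests

# Assumes a server started with something like:
#   python -m sglang.launch_server --model-path <model> --port 30000
resp = requests.get("http://127.0.0.1:30000/v1/models", timeout=5)
for card in resp.json()["data"]:
    print(card["id"], card["owned_by"])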
sglang/srt/server_args.py CHANGED
@@ -33,7 +33,7 @@ class ServerArgs:
 
      # Other runtime options
      tp_size: int = 1
-     stream_interval: int = 8
+     stream_interval: int = 1
      random_seed: Optional[int] = None
 
      # Logging
@@ -57,6 +57,7 @@ class ServerArgs:
      disable_disk_cache: bool = False
      attention_reduce_in_fp32: bool = False
      enable_p2p_check: bool = False
+     efficient_weight_load: bool = False
 
      # Distributed args
      nccl_init_addr: Optional[str] = None
@@ -166,6 +167,15 @@ class ServerArgs:
              "--quantization",
              type=str,
              default=ServerArgs.quantization,
+             choices=[
+                 "awq",
+                 "fp8",
+                 "gptq",
+                 "marlin",
+                 "gptq_marlin",
+                 "squeezellm",
+                 "bitsandbytes",
+             ],
              help="The quantization method.",
          )
          parser.add_argument(
@@ -243,13 +253,13 @@ class ServerArgs:
          parser.add_argument(
              "--show-time-cost",
              action="store_true",
-             help="Show time cost of custom marks",
+             help="Show time cost of custom marks.",
          )
          parser.add_argument(
              "--api-key",
              type=str,
              default=ServerArgs.api_key,
-             help="Set API key of the server",
+             help="Set API key of the server.",
          )
 
          # Data parallelism
@@ -285,17 +295,17 @@ class ServerArgs:
          parser.add_argument(
              "--disable-flashinfer",
              action="store_true",
-             help="Disable flashinfer inference kernels",
+             help="Disable flashinfer inference kernels.",
          )
          parser.add_argument(
              "--disable-radix-cache",
              action="store_true",
-             help="Disable RadixAttention",
+             help="Disable RadixAttention for prefix caching.",
          )
          parser.add_argument(
              "--disable-regex-jump-forward",
              action="store_true",
-             help="Disable regex jump-forward",
+             help="Disable regex jump-forward.",
          )
          parser.add_argument(
              "--disable-cuda-graph",
@@ -318,6 +328,11 @@ class ServerArgs:
              action="store_true",
              help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
          )
+         parser.add_argument(
+             "--efficient-weight-load",
+             action="store_true",
+             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+         )
 
      @classmethod
      def from_cli_args(cls, args: argparse.Namespace):
@@ -337,16 +352,9 @@ class ServerArgs:
          )
 
 
- @dataclasses.dataclass
- class ModelPortArgs:
-     nccl_port: int
-     model_tp_ips: List[str]
-     model_tp_ports: List[int]
-
-
  @dataclasses.dataclass
  class PortArgs:
      tokenizer_port: int
-     router_port: int
+     controller_port: int
      detokenizer_port: int
-     model_port_args: List[ModelPortArgs]
+     nccl_ports: List[int]
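The --quantization flag is now validated against an explicit choices list, and --efficient-weight-load is new. A hedged sketch of turning these flags into a ServerArgs instance; ServerArgs.add_cli_args and the --model-path flag are assumed from the surrounding module and are not shown in this diff:

import argparse
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed helper that owns the parser.add_argument calls above
args = parser.parse_args(
    [
        "--model-path", "placeholder/model",  # placeholder path, assumed required by the CLI
        "--quantization", "fp8",              # now checked against the choices list
        "--efficient-weight-load",            # new flag in 0.1.22
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.quantization, server_args.efficient_weight_load, server_args.stream_interval)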