sglang-0.1.20-py3-none-any.whl → sglang-0.1.22-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py ADDED
@@ -0,0 +1,208 @@
+ """Pydantic models for OpenAI API protocol"""
+ 
+ import time
+ from typing import Dict, List, Optional, Union
+ 
+ from pydantic import BaseModel, Field
+ from typing_extensions import Literal
+ 
+ 
+ class ModelCard(BaseModel):
+     """Model cards."""
+ 
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "sglang"
+     root: Optional[str] = None
+ 
+ 
+ class ModelList(BaseModel):
+     """Model list consists of model cards."""
+ 
+     object: str = "list"
+     data: List[ModelCard] = []
+ 
+ 
+ class ErrorResponse(BaseModel):
+     object: str = "error"
+     message: str
+     type: str
+     param: Optional[str] = None
+     code: int
+ 
+ 
+ class LogProbs(BaseModel):
+     text_offset: List[int] = Field(default_factory=list)
+     token_logprobs: List[Optional[float]] = Field(default_factory=list)
+     tokens: List[str] = Field(default_factory=list)
+     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+ 
+ 
+ class UsageInfo(BaseModel):
+     prompt_tokens: int = 0
+     total_tokens: int = 0
+     completion_tokens: Optional[int] = 0
+ 
+ 
+ class CompletionRequest(BaseModel):
+     # Ordered by official OpenAI API documentation
+     # https://platform.openai.com/docs/api-reference/completions/create
+     model: str
+     prompt: Union[List[int], List[List[int]], str, List[str]]
+     best_of: Optional[int] = None
+     echo: Optional[bool] = False
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     logprobs: Optional[int] = None
+     max_tokens: Optional[int] = 16
+     n: int = 1
+     presence_penalty: Optional[float] = 0.0
+     seed: Optional[int] = None
+     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+     stream: Optional[bool] = False
+     suffix: Optional[str] = None
+     temperature: Optional[float] = 1.0
+     top_p: Optional[float] = 1.0
+     user: Optional[str] = None
+ 
+     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+     regex: Optional[str] = None
+     ignore_eos: Optional[bool] = False
+ 
+ 
+ class CompletionResponseChoice(BaseModel):
+     index: int
+     text: str
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+ 
+ 
+ class CompletionResponse(BaseModel):
+     id: str
+     object: str = "text_completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[CompletionResponseChoice]
+     usage: UsageInfo
+ 
+ 
+ class CompletionResponseStreamChoice(BaseModel):
+     index: int
+     text: str
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+ 
+ 
+ class CompletionStreamResponse(BaseModel):
+     id: str
+     object: str = "text_completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[CompletionResponseStreamChoice]
+     usage: UsageInfo
+ 
+ 
+ class ChatCompletionMessageGenericParam(BaseModel):
+     role: Literal["system", "assistant"]
+     content: str
+ 
+ 
+ class ChatCompletionMessageContentTextPart(BaseModel):
+     type: Literal["text"]
+     text: str
+ 
+ 
+ class ChatCompletionMessageContentImageURL(BaseModel):
+     url: str
+     detail: Optional[Literal["auto", "low", "high"]] = "auto"
+ 
+ 
+ class ChatCompletionMessageContentImagePart(BaseModel):
+     type: Literal["image_url"]
+     image_url: ChatCompletionMessageContentImageURL
+ 
+ 
+ ChatCompletionMessageContentPart = Union[
+     ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+ ]
+ 
+ 
+ class ChatCompletionMessageUserParam(BaseModel):
+     role: Literal["user"]
+     content: Union[str, List[ChatCompletionMessageContentPart]]
+ 
+ 
+ ChatCompletionMessageParam = Union[
+     ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+ ]
+ 
+ 
+ class ResponseFormat(BaseModel):
+     # type must be "json_object" or "text"
+     type: Literal["text", "json_object"]
+ 
+ 
+ class ChatCompletionRequest(BaseModel):
+     # Ordered by official OpenAI API documentation
+     # https://platform.openai.com/docs/api-reference/chat/create
+     messages: List[ChatCompletionMessageParam]
+     model: str
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     logprobs: Optional[bool] = False
+     top_logprobs: Optional[int] = None
+     max_tokens: Optional[int] = 16
+     n: Optional[int] = 1
+     presence_penalty: Optional[float] = 0.0
+     response_format: Optional[ResponseFormat] = None
+     seed: Optional[int] = None
+     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+     stream: Optional[bool] = False
+     temperature: Optional[float] = 0.7
+     top_p: Optional[float] = 1.0
+     user: Optional[str] = None
+ 
+     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+     regex: Optional[str] = None
+ 
+ 
+ class ChatMessage(BaseModel):
+     role: Optional[str] = None
+     content: Optional[str] = None
+ 
+ 
+ class ChatCompletionResponseChoice(BaseModel):
+     index: int
+     message: ChatMessage
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+ 
+ 
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionResponseChoice]
+     usage: UsageInfo
+ 
+ 
+ class DeltaMessage(BaseModel):
+     role: Optional[str] = None
+     content: Optional[str] = None
+ 
+ 
+ class ChatCompletionResponseStreamChoice(BaseModel):
+     index: int
+     delta: DeltaMessage
+     logprobs: Optional[LogProbs] = None
+     finish_reason: Optional[str] = None
+ 
+ 
+ class ChatCompletionStreamResponse(BaseModel):
+     id: str
+     object: str = "chat.completion.chunk"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     choices: List[ChatCompletionResponseStreamChoice]
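
The module above mirrors the OpenAI request/response schema as pydantic models. A minimal client-side sketch of how they might be instantiated (not part of the diff; the model id is a placeholder, the serialization call assumes the pydantic v1-style API, and `regex` is the SRT-only extension field):

    # Illustrative sketch only; model id and .dict() usage are assumptions.
    from sglang.srt.openai_api.protocol import (
        ChatCompletionMessageUserParam,
        ChatCompletionRequest,
        CompletionRequest,
    )

    completion = CompletionRequest(
        model="my-model",  # placeholder model id
        prompt="The capital city of France is",
        max_tokens=8,
        temperature=0.0,
        regex=r" [A-Z][a-z]+",  # SRT-only constrained decoding, ignored by OpenAI itself
    )

    chat = ChatCompletionRequest(
        model="my-model",
        messages=[ChatCompletionMessageUserParam(role="user", content="Hello!")],
    )

    print(completion.dict(exclude_none=True))  # pydantic v1-style serialization
    print(chat.dict(exclude_none=True))
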
sglang/srt/openai_protocol.py CHANGED
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
  from typing_extensions import Literal
  
  
+ class ModelCard(BaseModel):
+     """Model cards."""
+ 
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "sglang"
+     root: Optional[str] = None
+ 
+ 
+ class ModelList(BaseModel):
+     """Model list consists of model cards."""
+ 
+     object: str = "list"
+     data: List[ModelCard] = []
+ 
+ 
  class ErrorResponse(BaseModel):
      object: str = "error"
      message: str
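
This hunk back-ports the same ModelCard/ModelList models into the older openai_protocol module. A rough sketch of the JSON shape a ModelList serializes to (assuming the pydantic v1-style .json(); the id and timestamp are placeholders):

    # Illustrative sketch only.
    from sglang.srt.openai_protocol import ModelCard, ModelList

    cards = [ModelCard(id="my-model", root="my-model")]
    print(ModelList(data=cards).json())
    # -> {"object": "list", "data": [{"id": "my-model", "object": "model",
    #     "created": 1700000000, "owned_by": "sglang", "root": "my-model"}]}
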
sglang/srt/sampling_params.py CHANGED
@@ -20,6 +20,7 @@ class SamplingParams:
          spaces_between_special_tokens: bool = True,
          dtype: Optional[str] = None,
          regex: Optional[str] = None,
+         n: int = 1,
      ) -> None:
          self.temperature = temperature
          self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
          self.spaces_between_special_tokens = spaces_between_special_tokens
          self.dtype = dtype
          self.regex = regex
+         self.n = n
  
          # Process some special cases
          if self.temperature < _SAMPLING_EPS:
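
With `n` now stored on SamplingParams, parallel sampling can in principle also be requested through the native /generate route. A hedged sketch of such a request (host, port, and the assumption that /generate forwards the dict into SamplingParams are placeholders, not confirmed by the diff):

    import requests

    # Assumes an sglang server is already running locally; adjust host/port as needed.
    res = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0.7, "max_new_tokens": 8, "n": 2},
        },
        timeout=60,
    )
    print(res.json())
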
sglang/srt/server.py CHANGED
@@ -26,34 +26,33 @@ import uvloop
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse, Response, StreamingResponse
  
- from sglang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.constrained import disable_cache
  from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.controller.manager_multi import (
      start_controller_process as start_controller_process_multi,
  )
+ from sglang.srt.managers.controller.manager_single import launch_tp_servers
  from sglang.srt.managers.controller.manager_single import (
      start_controller_process as start_controller_process_single,
  )
- from sglang.srt.managers.controller.tp_worker import ModelTpService
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.openai_api_adapter import (
+ from sglang.srt.openai_api.adapter import (
      load_chat_template_for_openai_api,
      v1_chat_completions,
      v1_completions,
  )
- from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+ from sglang.srt.openai_api.protocol import ModelCard, ModelList
+ from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
      API_KEY_HEADER_NAME,
      APIKeyValidatorMiddleware,
      allocate_init_ports,
      assert_pkg_version,
      enable_show_time_cost,
-     receive_addrs,
-     send_addrs_to_rank_0,
-     start_rpyc_service_process,
+     set_ulimit,
  )
  from sglang.utils import get_exception_traceback
  
@@ -65,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
  app = FastAPI()
  tokenizer_manager = None
  
+ # Put some args for easily access
+ global_server_args_dict = {}
+ 
  
  @app.get("/health")
  async def health() -> Response:
@@ -96,6 +98,7 @@ async def flush_cache():
  
  
  async def generate_request(obj: GenerateReqInput, request: Request):
+     """Handle a generate request."""
      if obj.stream:
  
          async def stream_results():
@@ -136,7 +139,30 @@ async def openai_v1_chat_completions(raw_request: Request):
      return await v1_chat_completions(tokenizer_manager, raw_request)
  
  
- def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+ @app.get("/v1/models")
+ def available_models():
+     """Show available models."""
+     model_names = [tokenizer_manager.model_path]
+     model_cards = []
+     for model_name in model_names:
+         model_cards.append(ModelCard(id=model_name, root=model_name))
+     return ModelList(data=model_cards)
+ 
+ 
+ def _set_global_server_args(server_args: ServerArgs):
+     global global_server_args_dict
+     global_server_args_dict = {
+         "disable_flashinfer": server_args.disable_flashinfer,
+         "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+     }
+ 
+ 
+ def launch_server(
+     server_args: ServerArgs,
+     model_overide_args: Optional[dict] = None,
+     pipe_finish_writer: Optional[mp.connection.Connection] = None,
+ ):
+     """Launch an HTTP server."""
      global tokenizer_manager
  
      logging.basicConfig(
@@ -147,6 +173,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      # Set global environments
      os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
      os.environ["NCCL_CUMEM_ENABLE"] = "0"
+     os.environ["NCCL_NVLS_ENABLE"] = "0"
+     set_ulimit()
      if server_args.show_time_cost:
          enable_show_time_cost()
      if server_args.disable_disk_cache:
@@ -154,7 +182,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      if not server_args.disable_flashinfer:
          assert_pkg_version(
              "flashinfer",
-             "0.0.8",
+             "0.1.0",
              "Please uninstall the old version and "
              "reinstall the latest version by following the instructions "
              "at https://docs.flashinfer.ai/installation.html.",
@@ -162,68 +190,61 @@
      if server_args.chat_template:
          # TODO: replace this with huggingface transformers template
          load_chat_template_for_openai_api(server_args.chat_template)
+     _set_global_server_args(server_args)
  
      # Allocate ports
-     assert server_args.tp_size % server_args.nnodes == 0
-     tp_size_local = server_args.tp_size // server_args.nnodes
      server_args.port, server_args.additional_ports = allocate_init_ports(
          server_args.port,
          server_args.additional_ports,
-         tp_size_local,
          server_args.dp_size,
      )
- 
      ports = server_args.additional_ports
-     model_port_args = []
-     for i in range(server_args.dp_size):
-         model_port_args.append(
-             ModelPortArgs(
-                 nccl_port=ports[3 + i * (tp_size_local + 1)],
-                 model_tp_ips=[None] * tp_size_local,
-                 model_tp_ports=ports[
-                     3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                 ],
-             )
-         )
      port_args = PortArgs(
          tokenizer_port=ports[0],
-         router_port=ports[1],
+         controller_port=ports[1],
          detokenizer_port=ports[2],
-         model_port_args=model_port_args,
+         nccl_ports=ports[3:],
      )
  
-     # TODO multi-node dp is not supported
-     assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+     # Handle multi-node tensor parallelism
      if server_args.nnodes > 1:
+         assert server_args.dp_size == 1, "Multi-node dp is not supported."
+ 
          if server_args.node_rank != 0:
-             send_addrs_to_rank_0(model_port_args[0], server_args)
-         else:
-             receive_addrs(model_port_args[0], server_args)
-             for i in range(tp_size_local):
-                 start_rpyc_service_process(
-                     ModelTpService, model_port_args[0].model_tp_ports[i]
+             tp_size_local = server_args.tp_size // server_args.nnodes
+             gpu_ids = [
+                 i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+             ]
+             tp_rank_range = list(
+                 range(
+                     server_args.node_rank * tp_size_local,
+                     (server_args.node_rank + 1) * tp_size_local,
+                 )
              )
-         if server_args.node_rank != 0:
-             logger.info(
-                 f"[node_rank={server_args.node_rank}]: Listen for connections..."
+             procs = launch_tp_servers(
+                 gpu_ids,
+                 tp_rank_range,
+                 server_args,
+                 ports[3],
+                 model_overide_args,
              )
              while True:
                  pass
  
      # Launch processes
      tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
      pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
  
      if server_args.dp_size == 1:
          start_process = start_controller_process_single
      else:
          start_process = start_controller_process_multi
-     proc_router = mp.Process(
+     proc_controller = mp.Process(
          target=start_process,
-         args=(server_args, port_args, pipe_router_writer, model_overide_args),
+         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
      )
-     proc_router.start()
+     proc_controller.start()
      proc_detoken = mp.Process(
          target=start_detokenizer_process,
          args=(
@@ -235,68 +256,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
      proc_detoken.start()
  
      # Wait for the model to finish loading
-     router_init_state = pipe_router_reader.recv()
+     controller_init_state = pipe_controller_reader.recv()
      detoken_init_state = pipe_detoken_reader.recv()
  
-     if router_init_state != "init ok" or detoken_init_state != "init ok":
-         proc_router.kill()
+     if controller_init_state != "init ok" or detoken_init_state != "init ok":
+         proc_controller.kill()
          proc_detoken.kill()
          print(
-             f"Initialization failed. router_init_state: {router_init_state}", flush=True
+             f"Initialization failed. controller_init_state: {controller_init_state}",
+             flush=True,
          )
          print(
              f"Initialization failed. detoken_init_state: {detoken_init_state}",
              flush=True,
          )
          sys.exit(1)
-     assert proc_router.is_alive() and proc_detoken.is_alive()
+     assert proc_controller.is_alive() and proc_detoken.is_alive()
  
      if server_args.api_key and server_args.api_key != "":
          app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
  
      # Send a warmup request
-     def _wait_and_warmup():
-         headers = {}
-         url = server_args.url()
-         if server_args.api_key:
-             headers[API_KEY_HEADER_NAME] = server_args.api_key
- 
-         # Wait until the server is launched
-         for _ in range(120):
-             time.sleep(0.5)
-             try:
-                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                 break
-             except requests.exceptions.RequestException:
-                 pass
- 
-         # Send a warmup request
-         try:
-             for _ in range(server_args.dp_size):
-                 res = requests.post(
-                     url + "/generate",
-                     json={
-                         "text": "The capital city of France is",
-                         "sampling_params": {
-                             "temperature": 0,
-                             "max_new_tokens": 8,
-                         },
-                     },
-                     headers=headers,
-                     timeout=600,
-                 )
-                 assert res.status_code == 200
-         except Exception as e:
-             if pipe_finish_writer is not None:
-                 pipe_finish_writer.send(get_exception_traceback())
-             print(f"Initialization failed. warmup error: {e}", flush=True)
-             raise e
- 
-         logger.info("The server is fired up and ready to roll!")
-         if pipe_finish_writer is not None:
-             pipe_finish_writer.send("init ok")
- 
-     t = threading.Thread(target=_wait_and_warmup)
+     t = threading.Thread(
+         target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+     )
      t.start()
  
      # Listen for requests
@@ -313,6 +296,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
          t.join()
  
  
+ def _wait_and_warmup(server_args, pipe_finish_writer):
+     headers = {}
+     url = server_args.url()
+     if server_args.api_key:
+         headers[API_KEY_HEADER_NAME] = server_args.api_key
+ 
+     # Wait until the server is launched
+     for _ in range(120):
+         time.sleep(0.5)
+         try:
+             requests.get(url + "/get_model_info", timeout=5, headers=headers)
+             break
+         except requests.exceptions.RequestException:
+             pass
+ 
+     # Send a warmup request
+     try:
+         for _ in range(server_args.dp_size):
+             res = requests.post(
+                 url + "/generate",
+                 json={
+                     "text": "The capital city of France is",
+                     "sampling_params": {
+                         "temperature": 0,
+                         "max_new_tokens": 8,
+                     },
+                 },
+                 headers=headers,
+                 timeout=600,
+             )
+             assert res.status_code == 200
+     except Exception as e:
+         if pipe_finish_writer is not None:
+             pipe_finish_writer.send(get_exception_traceback())
+         print(f"Initialization failed. warmup error: {e}", flush=True)
+         raise e
+ 
+     logger.info("The server is fired up and ready to roll!")
+     if pipe_finish_writer is not None:
+         pipe_finish_writer.send("init ok")
+ 
+ 
  class Runtime:
      """
      A wrapper for the server.
@@ -334,7 +359,6 @@ class Runtime:
          self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
              self.server_args.port,
              self.server_args.additional_ports,
-             self.server_args.tp_size,
              self.server_args.dp_size,
          )
  
@@ -347,7 +371,7 @@ class Runtime:
          pipe_reader, pipe_writer = mp.Pipe(duplex=False)
          proc = mp.Process(
              target=launch_server,
-             args=(self.server_args, pipe_writer, model_overide_args),
+             args=(self.server_args, model_overide_args, pipe_writer),
          )
          proc.start()
          pipe_writer.close()
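
The new /v1/models route added in this file returns a ModelList containing a single ModelCard for the loaded model path. A small sketch of querying it against a running server (host and port are assumptions, adjust to whatever you launched with):

    import requests

    base = "http://127.0.0.1:30000"  # placeholder address of a running sglang server
    models = requests.get(base + "/v1/models").json()
    print([card["id"] for card in models["data"]])
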