sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py ADDED
@@ -0,0 +1,208 @@
+"""Pydantic models for OpenAI API protocol"""
+
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class CompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+    ignore_eos: Optional[bool] = False
+
+
+class CompletionResponseChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: str
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+    type: Literal["image_url"]
+    image_url: ChatCompletionMessageContentImageURL
+
+
+ChatCompletionMessageContentPart = Union[
+    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+]
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
+class ResponseFormat(BaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class ChatCompletionRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
+    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    regex: Optional[str] = None
+
+
+class ChatMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
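
These Pydantic models mirror the OpenAI completions and chat schemas used by the new openai_api package. A minimal usage sketch, assuming only what the definitions above show (the model name and message text are placeholders, and the model_dump/dict fallback covers pydantic v2 vs v1):

    # Illustrative only: build and serialize a chat request with the new models.
    from sglang.srt.openai_api.protocol import ChatCompletionRequest

    req = ChatCompletionRequest(
        model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=32,
    )
    payload = req.model_dump() if hasattr(req, "model_dump") else req.dict()
    print(payload["messages"][0]["content"])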
sglang/srt/openai_protocol.py CHANGED
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ModelCard(BaseModel):
+    """Model cards."""
+
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "sglang"
+    root: Optional[str] = None
+
+
+class ModelList(BaseModel):
+    """Model list consists of model cards."""
+
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
sglang/srt/sampling_params.py CHANGED
@@ -20,6 +20,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
+        n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -33,6 +34,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
+        self.n = n
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
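
SamplingParams now carries an n field, by analogy with the OpenAI n parameter the number of completions generated per request, defaulting to 1. A short sketch of constructing it directly; only the fields visible in the hunks above are taken from the source, the remaining keyword defaults are assumed:

    from sglang.srt.sampling_params import SamplingParams

    # Request two independent completions per prompt (illustrative values).
    params = SamplingParams(temperature=0.8, top_p=0.95, n=2)
    print(params.n)  # -> 2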
sglang/srt/server.py CHANGED
@@ -26,33 +26,33 @@ import uvloop
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
-from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
+from sglang.srt.managers.controller.manager_single import launch_tp_servers
 from sglang.srt.managers.controller.manager_single import (
-    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api_adapter import (
+from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
@@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+# Put some args for easily access
+global_server_args_dict = {}
+
 
 @app.get("/health")
 async def health() -> Response:
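
Two import paths change in this release: the frontend language backends move from sglang.backend.* to sglang.lang.backend.*, and the OpenAI-compatible layer moves from the single sglang/srt/openai_api_adapter.py module into the sglang/srt/openai_api package. Downstream code that imported these symbols directly would update accordingly; both new paths are taken verbatim from the import hunk above:

    # before (0.1.21)
    # from sglang.backend.runtime_endpoint import RuntimeEndpoint
    # from sglang.srt.openai_api_adapter import v1_chat_completions, v1_completions

    # after (0.1.24)
    from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
    from sglang.srt.openai_api.adapter import v1_chat_completions, v1_completions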
@@ -95,6 +98,7 @@ async def flush_cache():
 
 
 async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
     if obj.stream:
 
         async def stream_results():
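
The warmup logic later in this file posts a request of the shape below to the /generate route. A hedged client-side sketch of the same call; the host and port are placeholders, and the payload mirrors the warmup request shown further down:

    # Illustrative only: call /generate on a locally running server.
    import requests

    res = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=600,
    )
    print(res.json())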
@@ -135,7 +139,43 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    model_names = [tokenizer_manager.model_path]
+    model_cards = []
+    for model_name in model_names:
+        model_cards.append(ModelCard(id=model_name, root=model_name))
+    return ModelList(data=model_cards)
+
+
+def _set_global_server_args(server_args: ServerArgs):
+    global global_server_args_dict
+    global_server_args_dict = {
+        "disable_flashinfer": server_args.disable_flashinfer,
+        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+    }
+
+
+def _set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
+def launch_server(
+    server_args: ServerArgs,
+    model_overide_args: Optional[dict] = None,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """Launch an HTTP server."""
     global tokenizer_manager
 
     logging.basicConfig(
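
The new /v1/models route reports the served model path as an OpenAI-style model card, using the ModelCard and ModelList models defined in the protocol module above. A hedged example of querying it, with host and port as placeholders:

    # Illustrative only: list models from a running server.
    import requests

    models = requests.get("http://localhost:30000/v1/models").json()
    for card in models["data"]:
        print(card["id"], card["owned_by"])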
@@ -146,6 +186,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
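
set_ulimit is a new helper imported from sglang.srt.utils; its body is not part of this diff. A minimal sketch of what such a helper typically does, raising the soft open-file limit so a busy server does not run out of descriptors (this is an assumption about intent, not the package's actual implementation):

    import resource

    def set_ulimit(target_soft: int = 65535) -> None:
        # Best-effort bump of RLIMIT_NOFILE's soft limit toward the hard limit.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        if soft < target_soft:
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (min(target_soft, hard), hard))
            except ValueError:
                pass  # not permitted; keep the current limit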
@@ -153,7 +196,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -162,63 +205,65 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # TODO: replace this with huggingface transformers template
     load_chat_template_for_openai_api(server_args.chat_template)
 
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
+    _set_global_server_args(server_args)
+
     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
-
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )
+    logger.info(f"{server_args=}")
 
-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
 
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
-                                       (server_args.node_rank + 1) * tp_size_local))
-            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
-                                      port_args.model_port_args[0], model_overide_args)
+            gpu_ids = [
+                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
+            ]
+            tp_rank_range = list(
+                range(
+                    server_args.node_rank * tp_size_local,
+                    (server_args.node_rank + 1) * tp_size_local,
+                )
+            )
+            procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                ports[3],
+                model_overide_args,
+            )
             while True:
                 pass
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -230,68 +275,30 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
 
     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
 
-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}",
+            flush=True,
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
     # Send a warmup request
-    def _wait_and_warmup():
-        headers = {}
-        url = server_args.url()
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+    )
     t.start()
 
     # Listen for requests
@@ -308,6 +315,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()
 
 
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
@@ -329,7 +378,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
 
@@ -342,7 +390,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
        )
         proc.start()
         pipe_writer.close()
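
launch_server now takes model_overide_args before pipe_finish_writer, and both have defaults, which is why the Runtime wrapper above passes them positionally in the new order. A hedged sketch of a direct call; ServerArgs is assumed to accept model_path with defaults for its remaining fields, and the model path is a placeholder:

    # Illustrative only: run the HTTP server in the current process.
    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")
        launch_server(server_args, model_overide_args=None, pipe_finish_writer=None)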