sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api_adapter.py CHANGED
@@ -1,9 +1,12 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
 import json
 import os
+from http import HTTPStatus
 
-from fastapi import HTTPException, Request
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import StreamingResponse, JSONResponse
 
 from sglang.srt.conversation import (
     Conversation,
@@ -26,14 +29,36 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model
-
 
 chat_template_name = None
 
+
+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    return JSONResponse(content=error.model_dump(),
+                        status_code=error.code)
+
+
+def create_streaming_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name
 
@@ -73,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)
 
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     adapted_request = GenerateReqInput(
         text=request.prompt,
@@ -92,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                text = content["text"]
-                prompt_tokens = content["meta_info"]["prompt_tokens"]
-                completion_tokens = content["meta_info"]["completion_tokens"]
-
-                if not stream_buffer: # The first chunk
-                    if request.echo:
-                        # Prepend prompt in response text.
-                        text = request.prompt + text
-
-                if request.logprobs:
-                    # The first chunk and echo is enabled.
-                    if not stream_buffer and request.echo:
-                        prefill_token_logprobs = content["meta_info"][
-                            "prefill_token_logprobs"
-                        ]
-                        prefill_top_logprobs = content["meta_info"][
-                            "prefill_top_logprobs"
-                        ]
+            try:
+                async for content in tokenizer_manager.generate_request(
+                        adapted_request, raw_request):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer: # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                                n_prev_token:
+                            ],
+                        )
+
+                        n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                     else:
-                        prefill_token_logprobs = None
-                        prefill_top_logprobs = None
-
-                    logprobs = to_openai_style_logprobs(
-                        prefill_token_logprobs=prefill_token_logprobs,
-                        prefill_top_logprobs=prefill_top_logprobs,
-                        decode_token_logprobs=content["meta_info"][
-                            "decode_token_logprobs"
-                        ][n_prev_token:],
-                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                            n_prev_token:
-                        ],
-                    )
+                        logprobs = None
 
-                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
-                else:
-                    logprobs = None
-
-                delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
-                choice_data = CompletionResponseStreamChoice(
-                    index=0,
-                    text=delta,
-                    logprobs=logprobs,
-                    finish_reason=None,
-                )
-                chunk = CompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = content["text"]
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
 
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
-    ret = ret[0] if isinstance(ret, list) else ret
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
 
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
@@ -192,7 +226,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         index=0,
         text=text,
         logprobs=logprobs,
-        finish_reason=None, # TODO(comaniac): Add finish reason.
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
@@ -211,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)
 
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
 
     # Prep the data needed for the underlying GenerateReqInput:
     # - prompt: The full prompt string.
@@ -257,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()
 
     if adapted_request.stream:
 
@@ -265,46 +298,58 @@
             is_first = True
 
             stream_buffer = ""
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                if is_first:
-                    # First chunk with role
-                    is_first = False
+            try:
+                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
-                        delta=DeltaMessage(role="assistant"),
-                        finish_reason=None,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                text = content["text"]
-                delta = text[len(stream_buffer) :]
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
 
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=None, # TODO(comaniac): Add finish reason.
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = ChatCompletionResponse(
         id=ret["meta_info"]["id"],
@@ -332,7 +377,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)
 
-        # Not Supported yet
+        # Not supported yet
        ret_logprobs.text_offset.append(-1)
 
     def append_top_logprobs(top_logprobs):
@@ -353,4 +398,4 @@
     if decode_top_logprobs is not None:
         append_top_logprobs(decode_top_logprobs)
 
-    return ret_logprobs
+    return ret_logprobs
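Both streaming endpoints above now wrap generation in try/except and, on a ValueError, emit a data: {"error": ...} event built by create_streaming_error_response instead of dropping the connection. A minimal client-side sketch of how such a stream could be consumed follows; the localhost URL, port 30000, and the model name are illustrative assumptions, and only the requests library is used.

import json

import requests

# Illustrative request against a locally running sglang server (assumed address).
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",  # placeholder model name
        "prompt": "The capital city of France is",
        "max_tokens": 16,
        "stream": True,
    },
    stream=True,
)

for line in resp.iter_lines():
    # SSE frames look like b"data: <json>"; skip keep-alive blank lines.
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    event = json.loads(payload)
    if "error" in event:
        # Shape produced by create_streaming_error_response() in the diff above.
        print("server error:", event["error"]["message"])
        break
    print(event["choices"][0]["text"], end="", flush=True)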
sglang/srt/openai_protocol.py CHANGED
@@ -1,4 +1,5 @@
 """pydantic models for OpenAI API protocol"""
+
 import time
 from typing import Dict, List, Optional, Union
 
@@ -6,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal
 
 
+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -178,4 +187,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseStreamChoice]
+    choices: List[ChatCompletionResponseStreamChoice]
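For reference, a standalone sketch of how the new ErrorResponse model serializes into the error bodies used by the adapter; the class below is a copy of the model added above, pydantic v2 is assumed (matching the model_dump() calls in the adapter diff), and the example values are made up.

from http import HTTPStatus
from typing import Optional

from pydantic import BaseModel


class ErrorResponse(BaseModel):  # copy of the model added in openai_protocol.py
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


error = ErrorResponse(
    message="n != 1 is not supported",
    type="BadRequestError",
    code=HTTPStatus.BAD_REQUEST.value,
)
print(error.model_dump())
# {'object': 'error', 'message': 'n != 1 is not supported',
#  'type': 'BadRequestError', 'param': None, 'code': 400}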
sglang/srt/server.py CHANGED
@@ -9,7 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
+from http import HTTPStatus
+from typing import Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -27,19 +28,23 @@ from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
+from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
+from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
-from sglang.srt.server_args import PortArgs, ServerArgs
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
 )
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -69,7 +74,7 @@ async def get_server_args():
 
 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,24 +82,32 @@ async def flush_cache():
     )
 
 
-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
 
         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(stream_results(), media_type="text/event-stream")
+        return StreamingResponse(stream_results(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(obj))
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+
 
-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        return JSONResponse({"error": str(e)}, status_code=400)
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)
 
 
 @app.post("/v1/completions")
@@ -129,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        server_args.tp_size,
+        server_args.dp_size,
     )
+
+    # Init local models port args
+    ports = server_args.additional_ports
+    tp = server_args.tp_size
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp + 1)],
+                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
    )
 
     # Launch processes
@@ -144,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=start_router_process,
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
@@ -179,6 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -190,27 +222,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                success = True  # Set flag to True if request succeeds
                 break
             except requests.exceptions.RequestException as e:
                 pass
 
         # Send a warmup request
         try:
-            res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "Say this is a warmup request.",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 16,
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
                     },
-                },
-                headers=headers,
-                timeout=600,
-            )
-            assert res.status_code == 200
-        except Exception as e:
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
+        except Exception:
             if pipe_finish_writer is not None:
                 pipe_finish_writer.send(get_exception_traceback())
             print(f"Initialization failed. warmup error: {e}")
@@ -221,6 +253,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
@@ -237,19 +271,20 @@
 class Runtime:
     def __init__(
         self,
-        log_evel: str = "error",
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
 
         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )
 
         self.url = self.server_args.url()
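The ModelPortArgs slicing in launch_server implies a port layout of three leading ports (tokenizer, router, detokenizer) followed by tp_size + 1 ports per data-parallel replica (one NCCL port plus one per tensor-parallel worker). A small sketch of that arithmetic, with illustrative sizes and made-up port numbers:

# Sketch of the port layout implied by the new launch_server() code above.
tp_size = 2   # tensor-parallel workers per replica (illustrative)
dp_size = 2   # data-parallel replicas (illustrative)

# ports[0..2] -> tokenizer, router, detokenizer; then (tp_size + 1) ports per replica.
num_ports = 3 + dp_size * (tp_size + 1)
ports = list(range(30000, 30000 + num_ports))  # 30000 is just an example base

for i in range(dp_size):
    base = 3 + i * (tp_size + 1)
    nccl_port = ports[base]
    model_tp_ports = ports[base + 1 : 3 + (i + 1) * (tp_size + 1)]
    print(f"replica {i}: nccl_port={nccl_port}, model_tp_ports={model_tp_ports}")

# replica 0: nccl_port=30003, model_tp_ports=[30004, 30005]
# replica 1: nccl_port=30006, model_tp_ports=[30007, 30008]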