sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
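Among the changes listed above, the OpenAI-compatible adapter (sglang/srt/openai_api_adapter.py, shown next) replaces bare assertions with structured error responses. As a rough illustration of what a client now sees, the hedged sketch below sends an unsupported n != 1 completion request; the server address and model name are placeholders, not values taken from this diff.

# Hypothetical client probe of the new error handling (assumed local server and model name).
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",               # placeholder address
    json={"model": "default", "prompt": "Hello", "n": 2},  # n != 1 is rejected in 0.1.18
)
if resp.status_code != 200:
    err = resp.json()  # ErrorResponse fields: object, message, type, param, code
    print(err["type"], err["code"], err["message"])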
sglang/srt/openai_api_adapter.py CHANGED
@@ -1,9 +1,12 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
 import json
 import os
+from http import HTTPStatus

-from fastapi import HTTPException, Request
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse

 from sglang.srt.conversation import (
     Conversation,
@@ -26,14 +29,33 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model
-

 chat_template_name = None

+
+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name

@@ -73,8 +95,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     adapted_request = GenerateReqInput(
         text=request.prompt,
@@ -92,79 +114,95 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                text = content["text"]
-                prompt_tokens = content["meta_info"]["prompt_tokens"]
-                completion_tokens = content["meta_info"]["completion_tokens"]
-
-                if not stream_buffer:  # The first chunk
-                    if request.echo:
-                        # Prepend prompt in response text.
-                        text = request.prompt + text
-
-                if request.logprobs:
-                    # The first chunk and echo is enabled.
-                    if not stream_buffer and request.echo:
-                        prefill_token_logprobs = content["meta_info"][
-                            "prefill_token_logprobs"
-                        ]
-                        prefill_top_logprobs = content["meta_info"][
-                            "prefill_top_logprobs"
-                        ]
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer:  # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"][
+                                "decode_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["decode_token_logprobs"]
+                        )
                     else:
-                        prefill_token_logprobs = None
-                        prefill_top_logprobs = None
-
-                    logprobs = to_openai_style_logprobs(
-                        prefill_token_logprobs=prefill_token_logprobs,
-                        prefill_top_logprobs=prefill_top_logprobs,
-                        decode_token_logprobs=content["meta_info"][
-                            "decode_token_logprobs"
-                        ][n_prev_token:],
-                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                            n_prev_token:
-                        ],
-                    )
+                        logprobs = None

-                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
-                else:
-                    logprobs = None
-
-                delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
-                choice_data = CompletionResponseStreamChoice(
-                    index=0,
-                    text=delta,
-                    logprobs=logprobs,
-                    finish_reason=None,
-                )
-                chunk = CompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = content["text"]
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
-    ret = ret[0] if isinstance(ret, list) else ret
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))

+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
@@ -192,7 +230,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         index=0,
         text=text,
         logprobs=logprobs,
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = CompletionResponse(
         id=ret["meta_info"]["id"],
@@ -211,8 +249,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.
@@ -257,7 +295,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

@@ -265,46 +302,64 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True

             stream_buffer = ""
-            async for content in tokenizer_manager.generate_request(adapted_request):
-                if is_first:
-                    # First chunk with role
-                    is_first = False
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
-                        delta=DeltaMessage(role="assistant"),
-                        finish_reason=None,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                text = content["text"]
-                delta = text[len(stream_buffer) :]
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
         index=0,
         message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
+        finish_reason=ret["meta_info"]["finish_reason"],
     )
     response = ChatCompletionResponse(
         id=ret["meta_info"]["id"],
@@ -332,7 +387,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)

-        # Not Supported yet
+        # Not supported yet
         ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
@@ -353,4 +408,4 @@ def to_openai_style_logprobs(
     if decode_top_logprobs is not None:
         append_top_logprobs(decode_top_logprobs)

-    return ret_logprobs
+    return ret_logprobs
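With this change, a streaming request no longer aborts the HTTP response on a ValueError; the adapter emits the error as an in-stream event (via create_streaming_error_response) and then closes with "data: [DONE]". A minimal consumer sketch, assuming a locally running server and placeholder payload values:

# Hedged sketch: parse the SSE stream, including the new in-stream error events.
import json
import requests

with requests.post(
    "http://localhost:30000/v1/completions",   # placeholder address
    json={"model": "default", "prompt": "Hello", "max_tokens": 8, "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        if "error" in chunk:  # shape produced by create_streaming_error_response
            print("stream error:", chunk["error"]["message"])
            break
        print(chunk["choices"][0]["text"], end="", flush=True)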
sglang/srt/openai_protocol.py CHANGED
@@ -1,4 +1,5 @@
-"""pydantic models for OpenAI API protocol"""
+"""Pydantic models for OpenAI API protocol"""
+
 import time
 from typing import Dict, List, Optional, Union

@@ -6,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal


+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -178,4 +187,4 @@ class ChatCompletionStreamResponse(BaseModel):
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseStreamChoice]
+    choices: List[ChatCompletionResponseStreamChoice]
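The new ErrorResponse model is what both adapter helpers serialize. A minimal sketch of the resulting payload (the message and code values are illustrative, taken from the n != 1 check above):

from http import HTTPStatus
from sglang.srt.openai_protocol import ErrorResponse

error = ErrorResponse(
    message="n != 1 is not supported",
    type="BadRequestError",
    code=HTTPStatus.BAD_REQUEST.value,
)
print(error.model_dump())
# {'object': 'error', 'message': 'n != 1 is not supported',
#  'type': 'BadRequestError', 'param': None, 'code': 400}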
sglang/srt/server.py CHANGED
@@ -1,4 +1,7 @@
-"""SRT: SGLang Runtime"""
+"""
+The entry point of inference server.
+SRT = SGLang Runtime.
+"""

 import asyncio
 import dataclasses
@@ -9,7 +12,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
+from http import HTTPStatus
+from typing import Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -25,21 +29,36 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
-    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
-from sglang.srt.server_args import PortArgs, ServerArgs
+    load_chat_template_for_openai_api,
+    v1_chat_completions,
+    v1_completions,
+)
+from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
 from sglang.srt.utils import (
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    get_exception_traceback,
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
+from sglang.utils import get_exception_traceback
+
+
+logger = logging.getLogger(__name__)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -69,7 +88,7 @@ async def get_server_args():

 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -77,24 +96,35 @@ async def flush_cache():
     )


-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream")
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse(
+                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+            )
+

-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        return JSONResponse({"error": str(e)}, status_code=400)
+app.post("/generate")(generate_request)
+app.put("/generate")(generate_request)


 @app.post("/v1/completions")
@@ -121,31 +151,66 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
-    if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+    if not server_args.disable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
+                           "reinstall the latest version by following the instructions "
+                           "at https://docs.flashinfer.ai/installation.html.")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port, server_args.additional_ports, server_args.tp_size
+        server_args.port,
+        server_args.additional_ports,
+        tp_size_local,
+        server_args.dp_size,
     )
+
+    ports = server_args.additional_ports
+    model_port_args = []
+    for i in range(server_args.dp_size):
+        model_port_args.append(
+            ModelPortArgs(
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+            )
+        )
     port_args = PortArgs(
-        tokenizer_port=server_args.additional_ports[0],
-        router_port=server_args.additional_ports[1],
-        detokenizer_port=server_args.additional_ports[2],
-        nccl_port=server_args.additional_ports[3],
-        model_rpc_ports=server_args.additional_ports[4:],
+        tokenizer_port=ports[0],
+        router_port=ports[1],
+        detokenizer_port=ports[2],
+        model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+    if server_args.dp_size == 1:
+        start_process = start_controller_process_single
+    else:
+        start_process = start_controller_process_multi
     proc_router = mp.Process(
-        target=start_router_process,
+        target=start_process,
         args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
@@ -179,6 +244,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -190,43 +256,46 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             time.sleep(0.5)
             try:
                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                success = True  # Set flag to True if request succeeds
                 break
-            except requests.exceptions.RequestException as e:
+            except requests.exceptions.RequestException:
                 pass

         # Send a warmup request
         try:
-            res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "Say this is a warmup request.",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 16,
+            for _ in range(server_args.dp_size):
+                res = requests.post(
+                    url + "/generate",
+                    json={
+                        "text": "The capital city of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 8,
+                        },
                     },
-                },
-                headers=headers,
-                timeout=600,
-            )
-            assert res.status_code == 200
+                    headers=headers,
+                    timeout=600,
+                )
+                assert res.status_code == 200
         except Exception as e:
             if pipe_finish_writer is not None:
                 pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}")
+            print(f"Initialization failed. warmup error: {e}", flush=True)
             raise e

+        logger.info("The server is fired up and ready to roll!")
         if pipe_finish_writer is not None:
             pipe_finish_writer.send("init ok")

     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
             host=server_args.host,
             port=server_args.port,
-            log_level=server_args.log_level,
+            log_level=server_args.log_level_http or server_args.log_level,
             timeout_keep_alive=5,
             loop="uvloop",
         )
@@ -235,21 +304,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg


 class Runtime:
+    """
+    A wrapper for the server.
+    This is used for launching the server in a python program without
+    using the commond line interface.
+    """
+
     def __init__(
         self,
-        log_evel: str = "error",
+        log_level: str = "error",
         model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         """See the arguments in server_args.py::ServerArgs"""
-        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

         # Pre-allocate ports
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
             self.server_args.tp_size,
+            self.server_args.dp_size,
         )

         self.url = self.server_args.url()
@@ -304,7 +380,7 @@ class Runtime:
     async def add_request(
         self,
         prompt: str,
-        sampling_params,
+        sampling_params: Dict,
     ):
         json_data = {
             "text": prompt,