sglang 0.1.14-py3-none-any.whl → 0.1.21-py3-none-any.whl

This diff shows the changes between two package versions publicly released to one of the supported registries, and is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
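Of the changes listed above, the most visible restructuring is that the OpenAI-compatible routes move out of sglang/srt/server.py into the new sglang/srt/openai_api_adapter.py, with the request/response models promoted to sglang/srt/openai_protocol.py. Below is a minimal sketch of exercising the /v1/chat/completions route against a locally launched 0.1.21 server; the base URL and model name are placeholders (assumptions, not values taken from this diff), and the request fields follow the protocol classes referenced in the server.py diff that follows.

import requests

# Assumption: a server started with "python -m sglang.launch_server ..."
# is listening on this address; replace with your own host and port.
BASE_URL = "http://127.0.0.1:30000"

resp = requests.post(
    BASE_URL + "/v1/chat/completions",
    json={
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "The capital city of France is"}],
        "temperature": 0,
        "max_tokens": 8,
    },
    timeout=600,
)
# The adapter returns an OpenAI-style ChatCompletionResponse.
print(resp.json()["choices"][0]["message"]["content"])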
sglang/srt/server.py CHANGED
@@ -1,97 +1,68 @@
- """SRT: SGLang Runtime"""
+ """
+ The entry point of inference server.
+ SRT = SGLang Runtime.
+ """

  import asyncio
  import dataclasses
  import json
+ import logging
  import multiprocessing as mp
  import os
  import sys
  import threading
  import time
- from typing import List, Optional, Union
+ from http import HTTPStatus
+ from typing import Dict, Optional

- # Fix a Python bug
+ # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

  import aiohttp
  import psutil
- import pydantic
  import requests
  import uvicorn
  import uvloop
- from fastapi import FastAPI, HTTPException, Request
- from fastapi.responses import Response, StreamingResponse
- from pydantic import BaseModel
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
+
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.constrained import disable_cache
- from sglang.srt.conversation import (
- Conversation,
- SeparatorStyle,
- chat_template_exists,
- generate_chat_conv,
- register_conv_template,
- )
  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
- from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
- from sglang.srt.managers.openai_protocol import (
- ChatCompletionRequest,
- ChatCompletionResponse,
- ChatCompletionResponseChoice,
- ChatCompletionResponseStreamChoice,
- ChatCompletionStreamResponse,
- ChatMessage,
- CompletionRequest,
- CompletionResponse,
- CompletionResponseChoice,
- CompletionResponseStreamChoice,
- CompletionStreamResponse,
- DeltaMessage,
- LogProbs,
- UsageInfo,
+ from sglang.srt.managers.controller.manager_multi import (
+ start_controller_process as start_controller_process_multi,
+ )
+ from sglang.srt.managers.controller.manager_single import (
+ launch_tp_servers,
+ start_controller_process as start_controller_process_single,
  )
- from sglang.srt.managers.router.manager import start_router_process
+ from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+ from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import handle_port_init
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.responses import JSONResponse
-
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
- API_KEY_HEADER_NAME = "X-API-Key"
-
+ from sglang.srt.openai_api_adapter import (
+ load_chat_template_for_openai_api,
+ v1_chat_completions,
+ v1_completions,
+ )
+ from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+ from sglang.srt.utils import (
+ API_KEY_HEADER_NAME,
+ APIKeyValidatorMiddleware,
+ allocate_init_ports,
+ assert_pkg_version,
+ enable_show_time_cost,
+ receive_addrs,
+ send_addrs_to_rank_0,
+ )
+ from sglang.utils import get_exception_traceback

- class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
- def __init__(self, app, api_key: str):
- super().__init__(app)
- self.api_key = api_key
+ logger = logging.getLogger(__name__)

- async def dispatch(self, request: Request, call_next):
- # extract API key from the request headers
- api_key_header = request.headers.get(API_KEY_HEADER_NAME)
- if not api_key_header or api_key_header != self.api_key:
- return JSONResponse(
- status_code=403,
- content={"detail": "Invalid API Key"},
- )
- response = await call_next(request)
- return response
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


  app = FastAPI()
  tokenizer_manager = None
- chat_template_name = None
-
-
- # FIXME: Remove this once we drop support for pydantic 1.x
- IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
- def jsonify_pydantic_model(obj: BaseModel):
- if IS_PYDANTIC_1:
- return obj.json(ensure_ascii=False)
- return obj.model_dump_json()


  @app.get("/health")
@@ -115,7 +86,7 @@ async def get_server_args():

  @app.get("/flush_cache")
  async def flush_cache():
- await tokenizer_manager.flush_cache()
+ tokenizer_manager.flush_cache()
  return Response(
  content="Cache flushed.\nPlease check backend logs for more details. "
  "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -123,361 +94,129 @@ async def flush_cache():
  )


- async def detokenize_logprob_tokens(token_logprobs):
- token_ids = [tid for tid, _ in token_logprobs]
- token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
- return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
-
-
- async def stream_generator(obj: GenerateReqInput):
- async for out in tokenizer_manager.generate_request(obj):
- if obj.return_logprob and obj.return_text_in_logprobs:
- out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
- out["meta_info"]["token_logprob"]
- )
- yield out
-
-
- async def make_openai_style_logprobs(token_logprobs):
- ret_logprobs = LogProbs()
-
- for token_text, token_logprob in token_logprobs:
- ret_logprobs.tokens.append(token_text)
- ret_logprobs.token_logprobs.append(token_logprob)
-
- # Not supported yet.
- ret_logprobs.top_logprobs.append({})
- ret_logprobs.text_offset.append(-1)
- return ret_logprobs
-
-
- @app.post("/generate")
- async def generate_request(obj: GenerateReqInput):
- obj.post_init()
-
+ async def generate_request(obj: GenerateReqInput, request: Request):
  if obj.stream:

  async def stream_results():
- async for out in stream_generator(obj):
+ try:
+ async for out in tokenizer_manager.generate_request(obj, request):
+ yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+ except ValueError as e:
+ out = {"error": {"message": str(e)}}
  yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
  yield "data: [DONE]\n\n"

- return StreamingResponse(stream_results(), media_type="text/event-stream")
-
- ret = await tokenizer_manager.generate_request(obj).__anext__()
- if obj.return_logprob and obj.return_text_in_logprobs:
- ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
- ret["meta_info"]["token_logprob"]
+ return StreamingResponse(
+ stream_results(),
+ media_type="text/event-stream",
+ background=tokenizer_manager.create_abort_task(obj),
  )
-
- return ret
-
-
- @app.post("/v1/completions")
- async def v1_completions(raw_request: Request):
- request_json = await raw_request.json()
- request = CompletionRequest(**request_json)
-
- # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
- assert request.n == 1
-
- adapted_request = GenerateReqInput(
- text=request.prompt,
- sampling_params={
- "temperature": request.temperature,
- "max_new_tokens": request.max_tokens,
- "stop": request.stop,
- "top_p": request.top_p,
- "presence_penalty": request.presence_penalty,
- "frequency_penalty": request.frequency_penalty,
- "regex": request.regex,
- },
- return_logprob=request.logprobs is not None,
- return_text_in_logprobs=True,
- stream=request.stream,
- )
- adapted_request.post_init()
-
- if adapted_request.stream:
-
- async def gnerate_stream_resp():
- stream_buffer = ""
- n_prev_token = 0
- async for content in stream_generator(adapted_request):
- text = content["text"]
- prompt_tokens = content["meta_info"]["prompt_tokens"]
- completion_tokens = content["meta_info"]["completion_tokens"]
-
- if not stream_buffer: # The first chunk
- if request.echo:
- # Prepend prompt in response text.
- text = request.prompt + text
- else:
- # Skip prompt tokens if echo is disabled.
- n_prev_token = prompt_tokens
-
- if request.logprobs is not None:
- logprobs = await make_openai_style_logprobs(
- content["meta_info"]["token_logprob"][n_prev_token:]
- )
- n_prev_token = len(content["meta_info"]["token_logprob"])
- else:
- logprobs = None
-
- delta = text[len(stream_buffer) :]
- stream_buffer = content["text"]
- choice_data = CompletionResponseStreamChoice(
- index=0,
- text=delta,
- logprobs=logprobs,
- finish_reason=None,
- )
- chunk = CompletionStreamResponse(
- id=content["meta_info"]["id"],
- object="text_completion",
- choices=[choice_data],
- model=request.model,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
- )
- yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
- yield "data: [DONE]\n\n"
-
- return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
-
- # Non-streaming response.
- ret = await generate_request(adapted_request)
- ret = ret[0] if isinstance(ret, list) else ret
-
- prompt_tokens = ret["meta_info"]["prompt_tokens"]
- completion_tokens = ret["meta_info"]["completion_tokens"]
- text = ret["text"]
- token_logprob_pos = prompt_tokens
- if request.echo:
- token_logprob_pos = 0
- text = request.prompt + text
  else:
- token_logprob_pos = prompt_tokens
+ try:
+ ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+ return ret
+ except ValueError as e:
+ return JSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )

- logprobs = (
- await make_openai_style_logprobs(
- ret["meta_info"]["token_logprob"][token_logprob_pos:]
- )
- if request.logprobs is not None
- else None
- )
- choice_data = CompletionResponseChoice(
- index=0,
- text=text,
- logprobs=logprobs,
- finish_reason=None, # TODO(comaniac): Add finish reason.
- )

- response = CompletionResponse(
- id=ret["meta_info"]["id"],
- model=request.model,
- choices=[choice_data],
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
- )
- return response
+ app.post("/generate")(generate_request)
+ app.put("/generate")(generate_request)


- @app.post("/v1/chat/completions")
- async def v1_chat_completions(raw_request: Request):
- request_json = await raw_request.json()
- request = ChatCompletionRequest(**request_json)
-
- # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
- assert request.n == 1
-
- # Prep the data needed for the underlying GenerateReqInput:
- # - prompt: The full prompt string.
- # - stop: Custom stop tokens.
- # - image_data: None or a list of image strings (URLs or base64 strings).
- # None skips any image processing in GenerateReqInput.
- if not isinstance(request.messages, str):
- # Apply chat template and its stop strings.
- if chat_template_name is None:
- # This flow doesn't support the full OpenAI spec. Verify messages
- # has the right type before proceeding:
- for m in request.messages:
- if not isinstance(m.content, str):
- raise HTTPException(
- status_code=503,
- detail="Structured content requests not supported with "
- "HuggingFace Chat Templates. "
- "Make sure the server specifies a sglang chat template.",
- )
- prompt = tokenizer_manager.tokenizer.apply_chat_template(
- request.messages, tokenize=False, add_generation_prompt=True
- )
- stop = request.stop
- image_data = None
- else:
- conv = generate_chat_conv(request, chat_template_name)
- prompt = conv.get_prompt()
- image_data = conv.image_data
- stop = conv.stop_str or []
- if request.stop:
- if isinstance(request.stop, str):
- stop.append(request.stop)
- else:
- stop.extend(request.stop)
- else:
- # Use the raw prompt and stop strings if the messages is already a string.
- prompt = request.messages
- stop = request.stop
- image_data = None
-
- adapted_request = GenerateReqInput(
- text=prompt,
- image_data=image_data,
- sampling_params={
- "temperature": request.temperature,
- "max_new_tokens": request.max_tokens,
- "stop": stop,
- "top_p": request.top_p,
- "presence_penalty": request.presence_penalty,
- "frequency_penalty": request.frequency_penalty,
- "regex": request.regex,
- },
- stream=request.stream,
- )
- adapted_request.post_init()
-
- if adapted_request.stream:
-
- async def gnerate_stream_resp():
- is_first = True
-
- stream_buffer = ""
- async for content in stream_generator(adapted_request):
- if is_first:
- # First chunk with role
- is_first = False
- choice_data = ChatCompletionResponseStreamChoice(
- index=0,
- delta=DeltaMessage(role="assistant"),
- finish_reason=None,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
- text = content["text"]
- delta = text[len(stream_buffer) :]
- stream_buffer = text
- choice_data = ChatCompletionResponseStreamChoice(
- index=0, delta=DeltaMessage(content=delta), finish_reason=None
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- choices=[choice_data],
- model=request.model,
- )
- yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
- yield "data: [DONE]\n\n"
+ @app.post("/v1/completions")
+ async def openai_v1_completions(raw_request: Request):
+ return await v1_completions(tokenizer_manager, raw_request)

- return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")

- # Non-streaming response.
- ret = await generate_request(adapted_request)
- prompt_tokens = ret["meta_info"]["prompt_tokens"]
- completion_tokens = ret["meta_info"]["completion_tokens"]
- choice_data = ChatCompletionResponseChoice(
- index=0,
- message=ChatMessage(role="assistant", content=ret["text"]),
- finish_reason=None, # TODO(comaniac): Add finish reason.
- )
- response = ChatCompletionResponse(
- id=ret["meta_info"]["id"],
- model=request.model,
- choices=[choice_data],
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
- )
- return response
+ @app.post("/v1/chat/completions")
+ async def openai_v1_chat_completions(raw_request: Request):
+ return await v1_chat_completions(tokenizer_manager, raw_request)


- def launch_server(server_args, pipe_finish_writer):
+ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
  global tokenizer_manager
- global chat_template_name

- # disable disk cache if needed
+ logging.basicConfig(
+ level=getattr(logging, server_args.log_level.upper()),
+ format="%(message)s",
+ )
+
+ # Set global environments
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ if server_args.show_time_cost:
+ enable_show_time_cost()
  if server_args.disable_disk_cache:
  disable_cache()
-
- # Handle ports
- server_args.port, server_args.additional_ports = handle_port_init(
- server_args.port, server_args.additional_ports, server_args.tp_size
+ if not server_args.disable_flashinfer:
+ assert_pkg_version(
+ "flashinfer",
+ "0.0.8",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
+ if server_args.chat_template:
+ # TODO: replace this with huggingface transformers template
+ load_chat_template_for_openai_api(server_args.chat_template)
+
+ # Allocate ports
+ assert server_args.tp_size % server_args.nnodes == 0
+ tp_size_local = server_args.tp_size // server_args.nnodes
+ server_args.port, server_args.additional_ports = allocate_init_ports(
+ server_args.port,
+ server_args.additional_ports,
+ tp_size_local,
+ server_args.dp_size,
  )

+ ports = server_args.additional_ports
+ model_port_args = []
+ for i in range(server_args.dp_size):
+ model_port_args.append(
+ ModelPortArgs(
+ nccl_port=ports[3 + i * (tp_size_local + 1)],
+ model_tp_ips=[None] * tp_size_local,
+ model_tp_ports=ports[
+ 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+ ],
+ )
+ )
  port_args = PortArgs(
- tokenizer_port=server_args.additional_ports[0],
- router_port=server_args.additional_ports[1],
- detokenizer_port=server_args.additional_ports[2],
- nccl_port=server_args.additional_ports[3],
- model_rpc_ports=server_args.additional_ports[4:],
+ tokenizer_port=ports[0],
+ router_port=ports[1],
+ detokenizer_port=ports[2],
+ model_port_args=model_port_args,
  )

- # Load chat template if needed
- if server_args.chat_template is not None:
- print(f"Use chat template: {server_args.chat_template}")
- if not chat_template_exists(server_args.chat_template):
- if not os.path.exists(server_args.chat_template):
- raise RuntimeError(
- f"Chat template {server_args.chat_template} is not a built-in template name "
- "or a valid chat template file path."
- )
- with open(server_args.chat_template, "r") as filep:
- template = json.load(filep)
- try:
- sep_style = SeparatorStyle[template["sep_style"]]
- except KeyError:
- raise ValueError(
- f"Unknown separator style: {template['sep_style']}"
- ) from None
- register_conv_template(
- Conversation(
- name=template["name"],
- system_template=template["system"] + "\n{system_message}",
- system_message=template.get("system_message", ""),
- roles=(template["user"], template["assistant"]),
- sep_style=sep_style,
- sep=template.get("sep", "\n"),
- stop_str=template["stop_str"],
- ),
- override=True,
- )
- chat_template_name = template["name"]
- else:
- chat_template_name = server_args.chat_template
+ # Handle multi-node tp
+ if server_args.nnodes > 1:
+ assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
+ if server_args.node_rank != 0:
+ tp_size_local = server_args.tp_size // server_args.nnodes
+ gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+ tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+ (server_args.node_rank + 1) * tp_size_local))
+ procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+ port_args.model_port_args[0], model_overide_args)
+ while True:
+ pass

  # Launch processes
- tokenizer_manager = TokenizerManager(server_args, port_args)
+ tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
  pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
  pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

+ if server_args.dp_size == 1:
+ start_process = start_controller_process_single
+ else:
+ start_process = start_controller_process_multi
  proc_router = mp.Process(
- target=start_router_process,
- args=(
- server_args,
- port_args,
- pipe_router_writer,
- ),
+ target=start_process,
+ args=(server_args, port_args, pipe_router_writer, model_overide_args),
  )
  proc_router.start()
  proc_detoken = mp.Process(
@@ -497,128 +236,101 @@ def launch_server(server_args, pipe_finish_writer):
  if router_init_state != "init ok" or detoken_init_state != "init ok":
  proc_router.kill()
  proc_detoken.kill()
- print("router init state:", router_init_state)
- print("detoken init state:", detoken_init_state)
+ print(
+ f"Initialization failed. router_init_state: {router_init_state}", flush=True
+ )
+ print(
+ f"Initialization failed. detoken_init_state: {detoken_init_state}",
+ flush=True,
+ )
  sys.exit(1)
-
  assert proc_router.is_alive() and proc_detoken.is_alive()

  if server_args.api_key and server_args.api_key != "":
  app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

- def _launch_server():
- uvicorn.run(
- app,
- host=server_args.host,
- port=server_args.port,
- log_level=server_args.log_level,
- timeout_keep_alive=5,
- loop="uvloop",
- )
-
+ # Send a warmup request
  def _wait_and_warmup():
  headers = {}
  url = server_args.url()
- if server_args.api_key and server_args.api_key != "":
+ if server_args.api_key:
  headers[API_KEY_HEADER_NAME] = server_args.api_key

+ # Wait until the server is launched
  for _ in range(120):
  time.sleep(0.5)
  try:
  requests.get(url + "/get_model_info", timeout=5, headers=headers)
  break
- except requests.exceptions.RequestException as e:
+ except requests.exceptions.RequestException:
  pass
- else:
- if pipe_finish_writer is not None:
- pipe_finish_writer.send(str(e))
- else:
- print(e, flush=True)
- return

- # Warmup
+ # Send a warmup request
  try:
- # print("Warmup...", flush=True)
- res = requests.post(
- url + "/generate",
- json={
- "text": "Say this is a warmup request.",
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": 16,
+ for _ in range(server_args.dp_size):
+ res = requests.post(
+ url + "/generate",
+ json={
+ "text": "The capital city of France is",
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": 8,
+ },
  },
- },
- headers=headers,
- timeout=60,
- )
- # print(f"Warmup done. model response: {res.json()['text']}")
- # print("=" * 20, "Server is ready", "=" * 20, flush=True)
- except requests.exceptions.RequestException as e:
+ headers=headers,
+ timeout=600,
+ )
+ assert res.status_code == 200
+ except Exception as e:
  if pipe_finish_writer is not None:
- pipe_finish_writer.send(str(e))
- else:
- print(e, flush=True)
- return
+ pipe_finish_writer.send(get_exception_traceback())
+ print(f"Initialization failed. warmup error: {e}", flush=True)
+ raise e

+ logger.info("The server is fired up and ready to roll!")
  if pipe_finish_writer is not None:
  pipe_finish_writer.send("init ok")

  t = threading.Thread(target=_wait_and_warmup)
  t.start()
+
+ # Listen for requests
  try:
- _launch_server()
+ uvicorn.run(
+ app,
+ host=server_args.host,
+ port=server_args.port,
+ log_level=server_args.log_level_http or server_args.log_level,
+ timeout_keep_alive=5,
+ loop="uvloop",
+ )
  finally:
  t.join()


  class Runtime:
+ """
+ A wrapper for the server.
+ This is used for launching the server in a python program without
+ using the commond line interface.
+ """
+
  def __init__(
  self,
- model_path: str,
- tokenizer_path: Optional[str] = None,
- load_format: str = "auto",
- tokenizer_mode: str = "auto",
- trust_remote_code: bool = True,
- mem_fraction_static: float = ServerArgs.mem_fraction_static,
- max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
- context_length: int = ServerArgs.context_length,
- tp_size: int = 1,
- schedule_heuristic: str = "lpm",
- attention_reduce_in_fp32: bool = False,
- random_seed: int = 42,
  log_level: str = "error",
- disable_radix_cache: bool = False,
- enable_flashinfer: bool = False,
- disable_regex_jump_forward: bool = False,
- disable_disk_cache: bool = False,
- api_key: str = "",
- port: Optional[int] = None,
- additional_ports: Optional[Union[List[int], int]] = None,
+ model_overide_args: Optional[dict] = None,
+ *args,
+ **kwargs,
  ):
- host = "127.0.0.1"
- port, additional_ports = handle_port_init(port, additional_ports, tp_size)
- self.server_args = ServerArgs(
- model_path=model_path,
- tokenizer_path=tokenizer_path,
- host=host,
- port=port,
- additional_ports=additional_ports,
- load_format=load_format,
- tokenizer_mode=tokenizer_mode,
- trust_remote_code=trust_remote_code,
- mem_fraction_static=mem_fraction_static,
- max_prefill_num_token=max_prefill_num_token,
- context_length=context_length,
- tp_size=tp_size,
- schedule_heuristic=schedule_heuristic,
- attention_reduce_in_fp32=attention_reduce_in_fp32,
- random_seed=random_seed,
- log_level=log_level,
- disable_radix_cache=disable_radix_cache,
- enable_flashinfer=enable_flashinfer,
- disable_regex_jump_forward=disable_regex_jump_forward,
- disable_disk_cache=disable_disk_cache,
- api_key=api_key,
+ """See the arguments in server_args.py::ServerArgs"""
+ self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+
+ # Pre-allocate ports
+ self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+ self.server_args.port,
+ self.server_args.additional_ports,
+ self.server_args.tp_size,
+ self.server_args.dp_size,
  )

  self.url = self.server_args.url()
@@ -628,7 +340,10 @@ class Runtime:

  self.pid = None
  pipe_reader, pipe_writer = mp.Pipe(duplex=False)
- proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+ proc = mp.Process(
+ target=launch_server,
+ args=(self.server_args, pipe_writer, model_overide_args),
+ )
  proc.start()
  pipe_writer.close()
  self.pid = proc.pid
@@ -640,7 +355,9 @@ class Runtime:

  if init_state != "init ok":
  self.shutdown()
- raise RuntimeError("Launch failed. Please see the error messages above.")
+ raise RuntimeError(
+ "Initialization failed. Please see the error messages above."
+ )

  self.endpoint = RuntimeEndpoint(self.url)

@@ -668,14 +385,13 @@ class Runtime:
  async def add_request(
  self,
  prompt: str,
- sampling_params,
- ) -> None:
+ sampling_params: Dict,
+ ):
  json_data = {
  "text": prompt,
  "sampling_params": sampling_params,
  "stream": True,
  }
-
  pos = 0

  timeout = aiohttp.ClientTimeout(total=3 * 3600)
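Taken together, the new server.py exposes a plain /generate endpoint (the warmup request in the diff shows its payload shape) and a slimmed-down Runtime wrapper whose keyword arguments are forwarded directly to ServerArgs. The sketch below is a minimal illustration of both, not a snippet from the package itself; it assumes sglang 0.1.21 is installed and uses a placeholder model path.

import requests

from sglang.srt.server import Runtime

# Placeholder model; any value accepted by ServerArgs.model_path works here.
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
try:
    # Same payload shape as the warmup request shown in the diff above.
    res = requests.post(
        runtime.url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=600,
    )
    print(res.json()["text"])
finally:
    runtime.shutdown()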