sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/tracer.py +6 -4
  11. sglang/launch_server.py +2 -1
  12. sglang/srt/constrained/fsm_cache.py +1 -0
  13. sglang/srt/constrained/jump_forward.py +1 -0
  14. sglang/srt/conversation.py +2 -2
  15. sglang/srt/hf_transformers_utils.py +2 -1
  16. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  17. sglang/srt/layers/extend_attention.py +1 -0
  18. sglang/srt/layers/logits_processor.py +114 -54
  19. sglang/srt/layers/radix_attention.py +2 -1
  20. sglang/srt/layers/token_attention.py +1 -0
  21. sglang/srt/managers/detokenizer_manager.py +5 -1
  22. sglang/srt/managers/io_struct.py +12 -0
  23. sglang/srt/managers/router/infer_batch.py +70 -33
  24. sglang/srt/managers/router/manager.py +7 -2
  25. sglang/srt/managers/router/model_rpc.py +116 -73
  26. sglang/srt/managers/router/model_runner.py +111 -167
  27. sglang/srt/managers/router/radix_cache.py +46 -38
  28. sglang/srt/managers/tokenizer_manager.py +56 -11
  29. sglang/srt/memory_pool.py +5 -14
  30. sglang/srt/model_config.py +7 -0
  31. sglang/srt/models/commandr.py +376 -0
  32. sglang/srt/models/dbrx.py +413 -0
  33. sglang/srt/models/dbrx_config.py +281 -0
  34. sglang/srt/models/gemma.py +22 -20
  35. sglang/srt/models/llama2.py +23 -21
  36. sglang/srt/models/llava.py +12 -10
  37. sglang/srt/models/mixtral.py +27 -25
  38. sglang/srt/models/qwen.py +23 -21
  39. sglang/srt/models/qwen2.py +23 -21
  40. sglang/srt/models/stablelm.py +20 -21
  41. sglang/srt/models/yivl.py +6 -5
  42. sglang/srt/openai_api_adapter.py +356 -0
  43. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  44. sglang/srt/sampling_params.py +2 -0
  45. sglang/srt/server.py +68 -447
  46. sglang/srt/server_args.py +76 -49
  47. sglang/srt/utils.py +88 -32
  48. sglang/srt/weight_utils.py +402 -0
  49. sglang/test/test_programs.py +8 -7
  50. sglang/test/test_utils.py +195 -7
  51. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
  52. sglang-0.1.15.dist-info/RECORD +69 -0
  53. sglang-0.1.14.dist-info/RECORD +0 -64
  54. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  55. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
  56. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -3,6 +3,7 @@
  import asyncio
  import dataclasses
  import json
+ import logging
  import multiprocessing as mp
  import os
  import sys
@@ -10,88 +11,41 @@ import threading
  import time
  from typing import List, Optional, Union

- # Fix a Python bug
+ # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

  import aiohttp
  import psutil
- import pydantic
  import requests
  import uvicorn
  import uvloop
- from fastapi import FastAPI, HTTPException, Request
+ from fastapi import FastAPI, Request
  from fastapi.responses import Response, StreamingResponse
- from pydantic import BaseModel
+
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.constrained import disable_cache
- from sglang.srt.conversation import (
-     Conversation,
-     SeparatorStyle,
-     chat_template_exists,
-     generate_chat_conv,
-     register_conv_template,
- )
  from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
- from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
- from sglang.srt.managers.openai_protocol import (
-     ChatCompletionRequest,
-     ChatCompletionResponse,
-     ChatCompletionResponseChoice,
-     ChatCompletionResponseStreamChoice,
-     ChatCompletionStreamResponse,
-     ChatMessage,
-     CompletionRequest,
-     CompletionResponse,
-     CompletionResponseChoice,
-     CompletionResponseStreamChoice,
-     CompletionStreamResponse,
-     DeltaMessage,
-     LogProbs,
-     UsageInfo,
- )
+ from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.router.manager import start_router_process
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.openai_api_adapter import (
+     v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import handle_port_init
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.responses import JSONResponse
+ from sglang.srt.utils import (
+     allocate_init_ports,
+     assert_pkg_version,
+     enable_show_time_cost,
+     get_exception_traceback,
+     API_KEY_HEADER_NAME,
+     APIKeyValidatorMiddleware
+ )

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

- API_KEY_HEADER_NAME = "X-API-Key"
-
-
- class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-     def __init__(self, app, api_key: str):
-         super().__init__(app)
-         self.api_key = api_key
-
-     async def dispatch(self, request: Request, call_next):
-         # extract API key from the request headers
-         api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-         if not api_key_header or api_key_header != self.api_key:
-             return JSONResponse(
-                 status_code=403,
-                 content={"detail": "Invalid API Key"},
-             )
-         response = await call_next(request)
-         return response
-

  app = FastAPI()
  tokenizer_manager = None
- chat_template_name = None
-
-
- # FIXME: Remove this once we drop support for pydantic 1.x
- IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
- def jsonify_pydantic_model(obj: BaseModel):
-     if IS_PYDANTIC_1:
-         return obj.json(ensure_ascii=False)
-     return obj.model_dump_json()


  @app.get("/health")
@@ -123,34 +77,6 @@ async def flush_cache():
      )


- async def detokenize_logprob_tokens(token_logprobs):
-     token_ids = [tid for tid, _ in token_logprobs]
-     token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-     return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
-
-
- async def stream_generator(obj: GenerateReqInput):
-     async for out in tokenizer_manager.generate_request(obj):
-         if obj.return_logprob and obj.return_text_in_logprobs:
-             out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-                 out["meta_info"]["token_logprob"]
-             )
-         yield out
-
-
- async def make_openai_style_logprobs(token_logprobs):
-     ret_logprobs = LogProbs()
-
-     for token_text, token_logprob in token_logprobs:
-         ret_logprobs.tokens.append(token_text)
-         ret_logprobs.token_logprobs.append(token_logprob)
-
-         # Not supported yet.
-         ret_logprobs.top_logprobs.append({})
-         ret_logprobs.text_offset.append(-1)
-     return ret_logprobs
-
-
  @app.post("/generate")
  async def generate_request(obj: GenerateReqInput):
      obj.post_init()
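
These logprob helpers disappear from `server.py`; presumably equivalent logic now lives in the new `sglang/srt/openai_api_adapter.py` (+356 lines in the file list above). A standalone sketch of the shaping step under that assumption, with the pydantic `LogProbs` model stood in by a dataclass and the original's needless `async` dropped (it awaited nothing):

```python
# Sketch of the OpenAI-style logprob shaping removed above. The field
# names mirror the LogProbs model from openai_protocol.py; the dataclass
# stand-in is an assumption for self-containment.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class LogProbs:
    text_offset: List[int] = field(default_factory=list)
    token_logprobs: List[Optional[float]] = field(default_factory=list)
    tokens: List[str] = field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = field(default_factory=list)


def make_openai_style_logprobs(token_logprobs: List[Tuple[str, float]]) -> LogProbs:
    ret = LogProbs()
    for token_text, token_logprob in token_logprobs:
        ret.tokens.append(token_text)
        ret.token_logprobs.append(token_logprob)
        ret.top_logprobs.append({})  # top-k alternatives: not supported yet
        ret.text_offset.append(-1)   # character offsets: not supported yet
    return ret
```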
@@ -158,273 +84,50 @@ async def generate_request(obj: GenerateReqInput):
      if obj.stream:

          async def stream_results():
-             async for out in stream_generator(obj):
+             async for out in tokenizer_manager.generate_request(obj):
                  yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
              yield "data: [DONE]\n\n"

          return StreamingResponse(stream_results(), media_type="text/event-stream")

      ret = await tokenizer_manager.generate_request(obj).__anext__()
-     if obj.return_logprob and obj.return_text_in_logprobs:
-         ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-             ret["meta_info"]["token_logprob"]
-         )
-
      return ret


  @app.post("/v1/completions")
- async def v1_completions(raw_request: Request):
-     request_json = await raw_request.json()
-     request = CompletionRequest(**request_json)
-
-     # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-     assert request.n == 1
-
-     adapted_request = GenerateReqInput(
-         text=request.prompt,
-         sampling_params={
-             "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
-             "stop": request.stop,
-             "top_p": request.top_p,
-             "presence_penalty": request.presence_penalty,
-             "frequency_penalty": request.frequency_penalty,
-             "regex": request.regex,
-         },
-         return_logprob=request.logprobs is not None,
-         return_text_in_logprobs=True,
-         stream=request.stream,
-     )
-     adapted_request.post_init()
-
-     if adapted_request.stream:
-
-         async def gnerate_stream_resp():
-             stream_buffer = ""
-             n_prev_token = 0
-             async for content in stream_generator(adapted_request):
-                 text = content["text"]
-                 prompt_tokens = content["meta_info"]["prompt_tokens"]
-                 completion_tokens = content["meta_info"]["completion_tokens"]
-
-                 if not stream_buffer:  # The first chunk
-                     if request.echo:
-                         # Prepend prompt in response text.
-                         text = request.prompt + text
-                     else:
-                         # Skip prompt tokens if echo is disabled.
-                         n_prev_token = prompt_tokens
-
-                 if request.logprobs is not None:
-                     logprobs = await make_openai_style_logprobs(
-                         content["meta_info"]["token_logprob"][n_prev_token:]
-                     )
-                     n_prev_token = len(content["meta_info"]["token_logprob"])
-                 else:
-                     logprobs = None
-
-                 delta = text[len(stream_buffer) :]
-                 stream_buffer = content["text"]
-                 choice_data = CompletionResponseStreamChoice(
-                     index=0,
-                     text=delta,
-                     logprobs=logprobs,
-                     finish_reason=None,
-                 )
-                 chunk = CompletionStreamResponse(
-                     id=content["meta_info"]["id"],
-                     object="text_completion",
-                     choices=[choice_data],
-                     model=request.model,
-                     usage=UsageInfo(
-                         prompt_tokens=prompt_tokens,
-                         completion_tokens=completion_tokens,
-                         total_tokens=prompt_tokens + completion_tokens,
-                     ),
-                 )
-                 yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-             yield "data: [DONE]\n\n"
-
-         return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
-
-     # Non-streaming response.
-     ret = await generate_request(adapted_request)
-     ret = ret[0] if isinstance(ret, list) else ret
-
-     prompt_tokens = ret["meta_info"]["prompt_tokens"]
-     completion_tokens = ret["meta_info"]["completion_tokens"]
-     text = ret["text"]
-     token_logprob_pos = prompt_tokens
-     if request.echo:
-         token_logprob_pos = 0
-         text = request.prompt + text
-     else:
-         token_logprob_pos = prompt_tokens
-
-     logprobs = (
-         await make_openai_style_logprobs(
-             ret["meta_info"]["token_logprob"][token_logprob_pos:]
-         )
-         if request.logprobs is not None
-         else None
-     )
-     choice_data = CompletionResponseChoice(
-         index=0,
-         text=text,
-         logprobs=logprobs,
-         finish_reason=None,  # TODO(comaniac): Add finish reason.
-     )
-
-     response = CompletionResponse(
-         id=ret["meta_info"]["id"],
-         model=request.model,
-         choices=[choice_data],
-         usage=UsageInfo(
-             prompt_tokens=prompt_tokens,
-             completion_tokens=completion_tokens,
-             total_tokens=prompt_tokens + completion_tokens,
-         ),
-     )
-     return response
+ async def openai_v1_completions(raw_request: Request):
+     return await v1_completions(tokenizer_manager, raw_request)


  @app.post("/v1/chat/completions")
- async def v1_chat_completions(raw_request: Request):
-     request_json = await raw_request.json()
-     request = ChatCompletionRequest(**request_json)
-
-     # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-     assert request.n == 1
-
-     # Prep the data needed for the underlying GenerateReqInput:
-     # - prompt: The full prompt string.
-     # - stop: Custom stop tokens.
-     # - image_data: None or a list of image strings (URLs or base64 strings).
-     #     None skips any image processing in GenerateReqInput.
-     if not isinstance(request.messages, str):
-         # Apply chat template and its stop strings.
-         if chat_template_name is None:
-             # This flow doesn't support the full OpenAI spec. Verify messages
-             # has the right type before proceeding:
-             for m in request.messages:
-                 if not isinstance(m.content, str):
-                     raise HTTPException(
-                         status_code=503,
-                         detail="Structured content requests not supported with "
-                         "HuggingFace Chat Templates. "
-                         "Make sure the server specifies a sglang chat template.",
-                     )
-             prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                 request.messages, tokenize=False, add_generation_prompt=True
-             )
-             stop = request.stop
-             image_data = None
-         else:
-             conv = generate_chat_conv(request, chat_template_name)
-             prompt = conv.get_prompt()
-             image_data = conv.image_data
-             stop = conv.stop_str or []
-             if request.stop:
-                 if isinstance(request.stop, str):
-                     stop.append(request.stop)
-                 else:
-                     stop.extend(request.stop)
-     else:
-         # Use the raw prompt and stop strings if the messages is already a string.
-         prompt = request.messages
-         stop = request.stop
-         image_data = None
-
-     adapted_request = GenerateReqInput(
-         text=prompt,
-         image_data=image_data,
-         sampling_params={
-             "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
-             "stop": stop,
-             "top_p": request.top_p,
-             "presence_penalty": request.presence_penalty,
-             "frequency_penalty": request.frequency_penalty,
-             "regex": request.regex,
-         },
-         stream=request.stream,
-     )
-     adapted_request.post_init()
-
-     if adapted_request.stream:
-
-         async def gnerate_stream_resp():
-             is_first = True
-
-             stream_buffer = ""
-             async for content in stream_generator(adapted_request):
-                 if is_first:
-                     # First chunk with role
-                     is_first = False
-                     choice_data = ChatCompletionResponseStreamChoice(
-                         index=0,
-                         delta=DeltaMessage(role="assistant"),
-                         finish_reason=None,
-                     )
-                     chunk = ChatCompletionStreamResponse(
-                         id=content["meta_info"]["id"],
-                         choices=[choice_data],
-                         model=request.model,
-                     )
-                     yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                 text = content["text"]
-                 delta = text[len(stream_buffer) :]
-                 stream_buffer = text
-                 choice_data = ChatCompletionResponseStreamChoice(
-                     index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                 )
-                 chunk = ChatCompletionStreamResponse(
-                     id=content["meta_info"]["id"],
-                     choices=[choice_data],
-                     model=request.model,
-                 )
-                 yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-             yield "data: [DONE]\n\n"
-
-         return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
+ async def openai_v1_chat_completions(raw_request: Request):
+     return await v1_chat_completions(tokenizer_manager, raw_request)

-     # Non-streaming response.
-     ret = await generate_request(adapted_request)
-     prompt_tokens = ret["meta_info"]["prompt_tokens"]
-     completion_tokens = ret["meta_info"]["completion_tokens"]
-     choice_data = ChatCompletionResponseChoice(
-         index=0,
-         message=ChatMessage(role="assistant", content=ret["text"]),
-         finish_reason=None,  # TODO(comaniac): Add finish reason.
-     )
-     response = ChatCompletionResponse(
-         id=ret["meta_info"]["id"],
-         model=request.model,
-         choices=[choice_data],
-         usage=UsageInfo(
-             prompt_tokens=prompt_tokens,
-             completion_tokens=completion_tokens,
-             total_tokens=prompt_tokens + completion_tokens,
-         ),
-     )
-     return response

-
- def launch_server(server_args, pipe_finish_writer):
+ def launch_server(server_args: ServerArgs, pipe_finish_writer):
      global tokenizer_manager
-     global chat_template_name

-     # disable disk cache if needed
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     # Set global environments
+     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+     if server_args.show_time_cost:
+         enable_show_time_cost()
      if server_args.disable_disk_cache:
          disable_cache()
-
-     # Handle ports
-     server_args.port, server_args.additional_ports = handle_port_init(
+     if server_args.enable_flashinfer:
+         assert_pkg_version("flashinfer", "0.0.4")
+     if server_args.chat_template:
+         # TODO: replace this with huggingface transformers template
+         load_chat_template_for_openai_api(server_args.chat_template)
+
+     # Allocate ports
+     server_args.port, server_args.additional_ports = allocate_init_ports(
          server_args.port, server_args.additional_ports, server_args.tp_size
      )
-
      port_args = PortArgs(
          tokenizer_port=server_args.additional_ports[0],
          router_port=server_args.additional_ports[1],
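
Both OpenAI-compatible routes now just bind `tokenizer_manager` and delegate to the new adapter module, so the wire format should be unchanged. A request like the following would behave the same before and after the refactor; the URL, model name, and sampling values are illustrative, while the request fields mirror those read from the removed `ChatCompletionRequest` handling above:

```python
# Illustrative request against the OpenAI-compatible route; the URL,
# model name, and sampling values are assumptions, not taken from the diff.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Say hi in one word."}],
        "max_tokens": 16,
        "temperature": 0,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```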
@@ -433,39 +136,6 @@ def launch_server(server_args, pipe_finish_writer):
          model_rpc_ports=server_args.additional_ports[4:],
      )

-     # Load chat template if needed
-     if server_args.chat_template is not None:
-         print(f"Use chat template: {server_args.chat_template}")
-         if not chat_template_exists(server_args.chat_template):
-             if not os.path.exists(server_args.chat_template):
-                 raise RuntimeError(
-                     f"Chat template {server_args.chat_template} is not a built-in template name "
-                     "or a valid chat template file path."
-                 )
-             with open(server_args.chat_template, "r") as filep:
-                 template = json.load(filep)
-                 try:
-                     sep_style = SeparatorStyle[template["sep_style"]]
-                 except KeyError:
-                     raise ValueError(
-                         f"Unknown separator style: {template['sep_style']}"
-                     ) from None
-                 register_conv_template(
-                     Conversation(
-                         name=template["name"],
-                         system_template=template["system"] + "\n{system_message}",
-                         system_message=template.get("system_message", ""),
-                         roles=(template["user"], template["assistant"]),
-                         sep_style=sep_style,
-                         sep=template.get("sep", "\n"),
-                         stop_str=template["stop_str"],
-                     ),
-                     override=True,
-                 )
-             chat_template_name = template["name"]
-         else:
-             chat_template_name = server_args.chat_template
-
      # Launch processes
      tokenizer_manager = TokenizerManager(server_args, port_args)
      pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
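
The inline template loader is gone; `launch_server` now calls `load_chat_template_for_openai_api` from the adapter module instead (see the hunk above). For reference, the removed code consumed a JSON file of this shape. The key names come from the `template[...]` lookups above; the values, and the `LLAMA2` separator style, are illustrative assumptions:

```python
# Shape of a chat-template JSON file accepted by the removed loader,
# shown as a Python dict. Keys are taken from the code above; values
# (including the "LLAMA2" sep_style) are illustrative assumptions.
template = {
    "name": "my-llama-2-chat",
    "system": "[INST] <<SYS>>",
    "system_message": "You are a helpful assistant.",
    "user": "[INST]",
    "assistant": "[/INST]",
    "sep_style": "LLAMA2",
    "sep": "\n",
    "stop_str": "</s>",
}
```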
@@ -497,31 +167,21 @@ def launch_server(server_args, pipe_finish_writer):
      if router_init_state != "init ok" or detoken_init_state != "init ok":
          proc_router.kill()
          proc_detoken.kill()
-         print("router init state:", router_init_state)
-         print("detoken init state:", detoken_init_state)
+         print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
+         print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
          sys.exit(1)
-
      assert proc_router.is_alive() and proc_detoken.is_alive()

      if server_args.api_key and server_args.api_key != "":
          app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

-     def _launch_server():
-         uvicorn.run(
-             app,
-             host=server_args.host,
-             port=server_args.port,
-             log_level=server_args.log_level,
-             timeout_keep_alive=5,
-             loop="uvloop",
-         )
-
      def _wait_and_warmup():
          headers = {}
          url = server_args.url()
-         if server_args.api_key and server_args.api_key != "":
+         if server_args.api_key:
              headers[API_KEY_HEADER_NAME] = server_args.api_key

+         # Wait until the server is launched
          for _ in range(120):
              time.sleep(0.5)
              try:
@@ -529,16 +189,9 @@ def launch_server(server_args, pipe_finish_writer):
                  break
              except requests.exceptions.RequestException as e:
                  pass
-         else:
-             if pipe_finish_writer is not None:
-                 pipe_finish_writer.send(str(e))
-             else:
-                 print(e, flush=True)
-             return

-         # Warmup
+         # Send a warmup request
          try:
-             # print("Warmup...", flush=True)
              res = requests.post(
                  url + "/generate",
                  json={
@@ -551,14 +204,12 @@ def launch_server(server_args, pipe_finish_writer):
                  headers=headers,
                  timeout=60,
              )
-             # print(f"Warmup done. model response: {res.json()['text']}")
-             # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-         except requests.exceptions.RequestException as e:
+             assert res.status_code == 200
+         except Exception as e:
              if pipe_finish_writer is not None:
-                 pipe_finish_writer.send(str(e))
-             else:
-                 print(e, flush=True)
-             return
+                 pipe_finish_writer.send(get_exception_traceback())
+             print(f"Initialization failed. warmup error: {e}")
+             raise e

          if pipe_finish_writer is not None:
              pipe_finish_writer.send("init ok")
@@ -566,7 +217,14 @@ def launch_server(server_args, pipe_finish_writer):
      t = threading.Thread(target=_wait_and_warmup)
      t.start()
      try:
-         _launch_server()
+         uvicorn.run(
+             app,
+             host=server_args.host,
+             port=server_args.port,
+             log_level=server_args.log_level,
+             timeout_keep_alive=5,
+             loop="uvloop",
+         )
      finally:
          t.join()

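With the `_launch_server` wrapper inlined, `launch_server` itself now blocks in `uvicorn.run` while the `_wait_and_warmup` thread probes it. A minimal sketch of an in-process launch under that model; the model path and port are illustrative, and `pipe_finish_writer=None` relies on the `is not None` guards shown above:

```python
# Minimal sketch of an in-process launch; blocks in uvicorn.run until
# the server is stopped. model_path and port values are illustrative.
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", port=30000)
launch_server(server_args, pipe_finish_writer=None)
```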
@@ -574,52 +232,16 @@ def launch_server(server_args, pipe_finish_writer):
  class Runtime:
      def __init__(
          self,
-         model_path: str,
-         tokenizer_path: Optional[str] = None,
-         load_format: str = "auto",
-         tokenizer_mode: str = "auto",
-         trust_remote_code: bool = True,
-         mem_fraction_static: float = ServerArgs.mem_fraction_static,
-         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
-         context_length: int = ServerArgs.context_length,
-         tp_size: int = 1,
-         schedule_heuristic: str = "lpm",
-         attention_reduce_in_fp32: bool = False,
-         random_seed: int = 42,
-         log_level: str = "error",
-         disable_radix_cache: bool = False,
-         enable_flashinfer: bool = False,
-         disable_regex_jump_forward: bool = False,
-         disable_disk_cache: bool = False,
-         api_key: str = "",
-         port: Optional[int] = None,
-         additional_ports: Optional[Union[List[int], int]] = None,
+         log_evel="error",
+         *args,
+         **kwargs,
      ):
-         host = "127.0.0.1"
-         port, additional_ports = handle_port_init(port, additional_ports, tp_size)
-         self.server_args = ServerArgs(
-             model_path=model_path,
-             tokenizer_path=tokenizer_path,
-             host=host,
-             port=port,
-             additional_ports=additional_ports,
-             load_format=load_format,
-             tokenizer_mode=tokenizer_mode,
-             trust_remote_code=trust_remote_code,
-             mem_fraction_static=mem_fraction_static,
-             max_prefill_num_token=max_prefill_num_token,
-             context_length=context_length,
-             tp_size=tp_size,
-             schedule_heuristic=schedule_heuristic,
-             attention_reduce_in_fp32=attention_reduce_in_fp32,
-             random_seed=random_seed,
-             log_level=log_level,
-             disable_radix_cache=disable_radix_cache,
-             enable_flashinfer=enable_flashinfer,
-             disable_regex_jump_forward=disable_regex_jump_forward,
-             disable_disk_cache=disable_disk_cache,
-             api_key=api_key,
-         )
+         """See the arguments in server_args.py::ServerArgs"""
+         self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+
+         # Pre-allocate ports
+         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+             self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)

          self.url = self.server_args.url()
          self.generate_url = (
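
The constructor collapses its long parameter list into `*args, **kwargs` forwarded to `ServerArgs`, so any `ServerArgs` field can be passed by keyword. One quirk worth noting: the log-level parameter is spelled `log_evel` (sic) in this release. A usage sketch, with an illustrative model path:

```python
# Usage sketch for the slimmed-down Runtime constructor; every keyword
# is forwarded to ServerArgs. The model path is an illustrative example.
from sglang.srt.server import Runtime

runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf", tp_size=1)
print(runtime.url)  # base URL of the launched server
runtime.shutdown()
```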
@@ -640,7 +262,7 @@ class Runtime:

          if init_state != "init ok":
              self.shutdown()
-             raise RuntimeError("Launch failed. Please see the error messages above.")
+             raise RuntimeError("Initialization failed. Please see the error messages above.")

          self.endpoint = RuntimeEndpoint(self.url)

@@ -669,13 +291,12 @@ class Runtime:
          self,
          prompt: str,
          sampling_params,
-     ) -> None:
+     ):
          json_data = {
              "text": prompt,
              "sampling_params": sampling_params,
              "stream": True,
          }
-
          pos = 0

          timeout = aiohttp.ClientTimeout(total=3 * 3600)
@@ -693,4 +314,4 @@ class Runtime:
                      pos += len(cur)

      def __del__(self):
-         self.shutdown()
+         self.shutdown()
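
The two hunks above show the body of Runtime's async streaming helper: POST to the server with `"stream": True`, keep a cursor `pos` into the accumulated text, and yield only the newly generated suffix. A standalone client sketch of the same pattern; the `/generate` path and the `data: ...` framing are taken from `stream_results` earlier in the diff, while the URL, sampling values, and line-by-line SSE parsing are assumptions:

```python
# Standalone sketch of the cursor-based streaming pattern used by
# Runtime; URL and sampling parameters are illustrative assumptions.
import asyncio
import json

import aiohttp


async def stream_generate(base_url: str, prompt: str):
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": 32},
        "stream": True,
    }
    pos = 0
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(base_url + "/generate", json=payload) as resp:
            async for raw in resp.content:  # one SSE line per iteration
                line = raw.decode().strip()
                if not line.startswith("data: ") or line == "data: [DONE]":
                    continue
                out = json.loads(line[len("data: "):])
                cur = out["text"][pos:]  # only the newly generated suffix
                pos += len(cur)
                print(cur, end="", flush=True)


asyncio.run(stream_generate("http://127.0.0.1:30000", "Hello"))
```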