sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -3,6 +3,7 @@
 import asyncio
 import dataclasses
 import json
+import logging
 import multiprocessing as mp
 import os
 import sys
@@ -10,88 +11,41 @@ import threading
 import time
 from typing import List, Optional, Union
 
-# Fix a Python bug
+# Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
 import psutil
-import pydantic
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import Response, StreamingResponse
-from pydantic import BaseModel
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
-from sglang.srt.conversation import (
-    Conversation,
-    SeparatorStyle,
-    chat_template_exists,
-    generate_chat_conv,
-    register_conv_template,
-)
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
-from sglang.srt.managers.openai_protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    DeltaMessage,
-    LogProbs,
-    UsageInfo,
-)
+from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.openai_api_adapter import (
+    v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import handle_port_init
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import JSONResponse
+from sglang.srt.utils import (
+    allocate_init_ports,
+    assert_pkg_version,
+    enable_show_time_cost,
+    get_exception_traceback,
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request: Request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
 
 app = FastAPI()
 tokenizer_manager = None
-chat_template_name = None
-
-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
 
 
 @app.get("/health")
@@ -123,34 +77,6 @@ async def flush_cache():
     )
 
 
-async def detokenize_logprob_tokens(token_logprobs):
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
-
-
-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        if obj.return_logprob and obj.return_text_in_logprobs:
-            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-                out["meta_info"]["token_logprob"]
-            )
-        yield out
-
-
-async def make_openai_style_logprobs(token_logprobs):
-    ret_logprobs = LogProbs()
-
-    for token_text, token_logprob in token_logprobs:
-        ret_logprobs.tokens.append(token_text)
-        ret_logprobs.token_logprobs.append(token_logprob)
-
-        # Not supported yet.
-        ret_logprobs.top_logprobs.append({})
-        ret_logprobs.text_offset.append(-1)
-    return ret_logprobs
-
-
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
@@ -158,273 +84,53 @@ async def generate_request(obj: GenerateReqInput):
     if obj.stream:
 
         async def stream_results():
-            async for out in stream_generator(obj):
+            async for out in tokenizer_manager.generate_request(obj):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    if obj.return_logprob and obj.return_text_in_logprobs:
-        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
-            ret["meta_info"]["token_logprob"]
-        )
-
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)
 
 
 @app.post("/v1/completions")
-async def v1_completions(raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
-
-    adapted_request = GenerateReqInput(
-        text=request.prompt,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-        },
-        return_logprob=request.logprobs is not None,
-        return_text_in_logprobs=True,
-        stream=request.stream,
-    )
-    adapted_request.post_init()
-
-    if adapted_request.stream:
-
-        async def gnerate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
-            async for content in stream_generator(adapted_request):
-                text = content["text"]
-                prompt_tokens = content["meta_info"]["prompt_tokens"]
-                completion_tokens = content["meta_info"]["completion_tokens"]
-
-                if not stream_buffer:  # The first chunk
-                    if request.echo:
-                        # Prepend prompt in response text.
-                        text = request.prompt + text
-                    else:
-                        # Skip prompt tokens if echo is disabled.
-                        n_prev_token = prompt_tokens
-
-                if request.logprobs is not None:
-                    logprobs = await make_openai_style_logprobs(
-                        content["meta_info"]["token_logprob"][n_prev_token:]
-                    )
-                    n_prev_token = len(content["meta_info"]["token_logprob"])
-                else:
-                    logprobs = None
-
-                delta = text[len(stream_buffer) :]
-                stream_buffer = content["text"]
-                choice_data = CompletionResponseStreamChoice(
-                    index=0,
-                    text=delta,
-                    logprobs=logprobs,
-                    finish_reason=None,
-                )
-                chunk = CompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
-
-    # Non-streaming response.
-    ret = await generate_request(adapted_request)
-    ret = ret[0] if isinstance(ret, list) else ret
-
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    text = ret["text"]
-    token_logprob_pos = prompt_tokens
-    if request.echo:
-        token_logprob_pos = 0
-        text = request.prompt + text
-    else:
-        token_logprob_pos = prompt_tokens
-
-    logprobs = (
-        await make_openai_style_logprobs(
-            ret["meta_info"]["token_logprob"][token_logprob_pos:]
-        )
-        if request.logprobs is not None
-        else None
-    )
-    choice_data = CompletionResponseChoice(
-        index=0,
-        text=text,
-        logprobs=logprobs,
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
-    )
-
-    response = CompletionResponse(
-        id=ret["meta_info"]["id"],
-        model=request.model,
-        choices=[choice_data],
-        usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        ),
-    )
-    return response
+async def openai_v1_completions(raw_request: Request):
+    return await v1_completions(tokenizer_manager, raw_request)
 
 
 @app.post("/v1/chat/completions")
-async def v1_chat_completions(raw_request: Request):
-    request_json = await raw_request.json()
-    request = ChatCompletionRequest(**request_json)
-
-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    #  - prompt: The full prompt string.
-    #  - stop: Custom stop tokens.
-    #  - image_data: None or a list of image strings (URLs or base64 strings).
-    #    None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            # This flow doesn't support the full OpenAI spec. Verify messages
-            # has the right type before proceeding:
-            for m in request.messages:
-                if not isinstance(m.content, str):
-                    raise HTTPException(
-                        status_code=503,
-                        detail="Structured content requests not supported with "
-                        "HuggingFace Chat Templates. "
-                        "Make sure the server specifies a sglang chat template.",
-                    )
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-        },
-        stream=request.stream,
-    )
-    adapted_request.post_init()
-
-    if adapted_request.stream:
-
-        async def gnerate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            async for content in stream_generator(adapted_request):
-                if is_first:
-                    # First chunk with role
-                    is_first = False
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
-                        delta=DeltaMessage(role="assistant"),
-                        finish_reason=None,
-                    )
-                    chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        choices=[choice_data],
-                        model=request.model,
-                    )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-
-                text = content["text"]
-                delta = text[len(stream_buffer) :]
-                stream_buffer = text
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
+async def openai_v1_chat_completions(raw_request: Request):
+    return await v1_chat_completions(tokenizer_manager, raw_request)
 
-    # Non-streaming response.
-    ret = await generate_request(adapted_request)
-    prompt_tokens = ret["meta_info"]["prompt_tokens"]
-    completion_tokens = ret["meta_info"]["completion_tokens"]
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=ret["text"]),
-        finish_reason=None,  # TODO(comaniac): Add finish reason.
-    )
-    response = ChatCompletionResponse(
-        id=ret["meta_info"]["id"],
-        model=request.model,
-        choices=[choice_data],
-        usage=UsageInfo(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        ),
-    )
-    return response
 
-
-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager
-    global chat_template_name
 
-    # disable disk cache if needed
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    if server_args.show_time_cost:
+        enable_show_time_cost()
     if server_args.disable_disk_cache:
         disable_cache()
-
-    # Handle ports
-    server_args.port, server_args.additional_ports = handle_port_init(
+    if server_args.enable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.4")
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Allocate ports
+    server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port, server_args.additional_ports, server_args.tp_size
     )
-
     port_args = PortArgs(
         tokenizer_port=server_args.additional_ports[0],
         router_port=server_args.additional_ports[1],
@@ -433,51 +139,14 @@ def launch_server(server_args, pipe_finish_writer):
         model_rpc_ports=server_args.additional_ports[4:],
     )
 
-    # Load chat template if needed
-    if server_args.chat_template is not None:
-        print(f"Use chat template: {server_args.chat_template}")
-        if not chat_template_exists(server_args.chat_template):
-            if not os.path.exists(server_args.chat_template):
-                raise RuntimeError(
-                    f"Chat template {server_args.chat_template} is not a built-in template name "
-                    "or a valid chat template file path."
-                )
-            with open(server_args.chat_template, "r") as filep:
-                template = json.load(filep)
-                try:
-                    sep_style = SeparatorStyle[template["sep_style"]]
-                except KeyError:
-                    raise ValueError(
-                        f"Unknown separator style: {template['sep_style']}"
-                    ) from None
-                register_conv_template(
-                    Conversation(
-                        name=template["name"],
-                        system_template=template["system"] + "\n{system_message}",
-                        system_message=template.get("system_message", ""),
-                        roles=(template["user"], template["assistant"]),
-                        sep_style=sep_style,
-                        sep=template.get("sep", "\n"),
-                        stop_str=template["stop_str"],
-                    ),
-                    override=True,
-                )
-            chat_template_name = template["name"]
-        else:
-            chat_template_name = server_args.chat_template
-
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
    proc_router = mp.Process(
        target=start_router_process,
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
    proc_detoken = mp.Process(
@@ -497,48 +166,37 @@ def launch_server(server_args, pipe_finish_writer):
    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
-        print("router init state:", router_init_state)
-        print("detoken init state:", detoken_init_state)
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
        sys.exit(1)
-
    assert proc_router.is_alive() and proc_detoken.is_alive()
 
    if server_args.api_key and server_args.api_key != "":
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
-    def _launch_server():
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
-
    def _wait_and_warmup():
        headers = {}
        url = server_args.url()
-        if server_args.api_key and server_args.api_key != "":
+        if server_args.api_key:
            headers[API_KEY_HEADER_NAME] = server_args.api_key
 
+        # Wait until the server is launched
        for _ in range(120):
            time.sleep(0.5)
            try:
                requests.get(url + "/get_model_info", timeout=5, headers=headers)
+                success = True  # Set flag to True if request succeeds
                break
            except requests.exceptions.RequestException as e:
                pass
-        else:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
 
-        # Warmup
+        # Send a warmup request
        try:
-            # print("Warmup...", flush=True)
            res = requests.post(
                url + "/generate",
                json={
@@ -549,16 +207,14 @@ def launch_server(server_args, pipe_finish_writer):
                    },
                },
                headers=headers,
-                timeout=60,
+                timeout=600,
            )
-            # print(f"Warmup done. model response: {res.json()['text']}")
-            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-        except requests.exceptions.RequestException as e:
+            assert res.status_code == 200
+        except Exception as e:
            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
+                pipe_finish_writer.send(get_exception_traceback())
+            print(f"Initialization failed. warmup error: {e}")
+            raise e
 
        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")
@@ -566,7 +222,14 @@ def launch_server(server_args, pipe_finish_writer):
    t = threading.Thread(target=_wait_and_warmup)
    t.start()
    try:
-        _launch_server()
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
    finally:
        t.join()
 
@@ -574,51 +237,19 @@ def launch_server(server_args, pipe_finish_writer):
 class Runtime:
    def __init__(
        self,
-        model_path: str,
-        tokenizer_path: Optional[str] = None,
-        load_format: str = "auto",
-        tokenizer_mode: str = "auto",
-        trust_remote_code: bool = True,
-        mem_fraction_static: float = ServerArgs.mem_fraction_static,
-        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
-        context_length: int = ServerArgs.context_length,
-        tp_size: int = 1,
-        schedule_heuristic: str = "lpm",
-        attention_reduce_in_fp32: bool = False,
-        random_seed: int = 42,
-        log_level: str = "error",
-        disable_radix_cache: bool = False,
-        enable_flashinfer: bool = False,
-        disable_regex_jump_forward: bool = False,
-        disable_disk_cache: bool = False,
-        api_key: str = "",
-        port: Optional[int] = None,
-        additional_ports: Optional[Union[List[int], int]] = None,
+        log_evel: str = "error",
+        model_overide_args: Optional[dict] = None,
+        *args,
+        **kwargs,
    ):
-        host = "127.0.0.1"
-        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
-        self.server_args = ServerArgs(
-            model_path=model_path,
-            tokenizer_path=tokenizer_path,
-            host=host,
-            port=port,
-            additional_ports=additional_ports,
-            load_format=load_format,
-            tokenizer_mode=tokenizer_mode,
-            trust_remote_code=trust_remote_code,
-            mem_fraction_static=mem_fraction_static,
-            max_prefill_num_token=max_prefill_num_token,
-            context_length=context_length,
-            tp_size=tp_size,
-            schedule_heuristic=schedule_heuristic,
-            attention_reduce_in_fp32=attention_reduce_in_fp32,
-            random_seed=random_seed,
-            log_level=log_level,
-            disable_radix_cache=disable_radix_cache,
-            enable_flashinfer=enable_flashinfer,
-            disable_regex_jump_forward=disable_regex_jump_forward,
-            disable_disk_cache=disable_disk_cache,
-            api_key=api_key,
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)
+
+        # Pre-allocate ports
+        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
        )
 
        self.url = self.server_args.url()
@@ -628,7 +259,10 @@ class Runtime:
 
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid
@@ -640,7 +274,9 @@ class Runtime:
 
        if init_state != "init ok":
            self.shutdown()
-            raise RuntimeError("Launch failed. Please see the error messages above.")
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
 
        self.endpoint = RuntimeEndpoint(self.url)
 
@@ -669,13 +305,12 @@ class Runtime:
        self,
        prompt: str,
        sampling_params,
-    ) -> None:
+    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }
-
        pos = 0
 
        timeout = aiohttp.ClientTimeout(total=3 * 3600)
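
A minimal usage sketch (not part of the packaged code) of the refactored Runtime constructor shown in the diff above: positional and keyword arguments are now forwarded to ServerArgs, and ports are pre-allocated via allocate_init_ports before the server process starts. The model path and settings below are illustrative assumptions, not values fixed by this release.

# Hedged sketch: exercising the 0.1.16 Runtime API.
# Any ServerArgs field can be passed as a keyword argument here.
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # hypothetical model, just an example
    tp_size=1,
    log_evel="error",          # note: the parameter is spelled `log_evel` in this version
    model_overide_args=None,   # optional dict forwarded to the tokenizer/router processes
)
print(runtime.url)             # base URL of the launched server, e.g. http://127.0.0.1:<port>
runtime.shutdown()             # terminate the server process when done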