sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
-import psutil
 import requests
 import uvicorn
 import uvloop
@@ -52,11 +50,16 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
+    v1_cancel_batch,
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
@@ -72,6 +75,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -92,10 +96,25 @@ tokenizer_manager = None
 
 @app.get("/health")
 async def health() -> Response:
-    """Health check."""
+    """Check the health of the http server."""
     return Response(status_code=200)
 
 
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+    gri = GenerateReqInput(
+        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+    )
+    try:
+        async for _ in tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
@@ -120,6 +139,23 @@ async def flush_cache():
     )
 
 
+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -211,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(tokenizer_manager, batch_id)
+
+
 @app.get("/v1/batches/{batch_id}")
 async def retrieve_batch(batch_id: str):
     return await v1_retrieve_batch(batch_id)
@@ -236,15 +278,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager
 
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
@@ -264,27 +303,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
-            tp_rank_range = list(
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
             )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
-            )
-            while True:
-                pass
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -297,11 +338,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -319,15 +362,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -336,12 +375,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
        uvicorn.run(
            app,
            host=server_args.host,
@@ -382,14 +421,14 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.5",
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -412,8 +451,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -438,21 +478,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
                 timeout=600,
             )
             assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
-
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -490,6 +522,7 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
@@ -566,15 +599,17 @@ class Runtime:
 
     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
@@ -585,7 +620,7 @@ class Runtime:
 
     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
sglang/srt/server_args.py CHANGED
@@ -33,11 +33,13 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False
 
     # Port
     host: str = "127.0.0.1"
@@ -79,12 +81,14 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-    attention_reduce_in_fp32: bool = False
-    efficient_weight_load: bool = False
+    triton_attention_reduce_in_fp32: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -193,11 +197,23 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -391,11 +407,27 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -409,13 +441,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mla",
             action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
@@ -433,15 +465,6 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
     def check_server_args(self):
         assert (
             self.tp_size % self.nnodes == 0
@@ -449,8 +472,13 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
-            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
 
 
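A short sketch of how the new server arguments could be set programmatically through the ServerArgs dataclass shown above; the model path is a placeholder and every field not listed keeps its default.

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    kv_cache_dtype="fp8_e5m2",         # new: FP8 E5M2 KV cache (CUDA 11.8+)
    is_embedding=False,                # new: set True to serve a CausalLM as an embedding model
    disable_cuda_graph_padding=True,   # new: skip cuda graph only when padding is required
    disable_custom_all_reduce=False,   # new: set True to fall back to NCCL all-reduce
    enable_mixed_chunk=True,           # new: mix prefill and decode in one batch (chunked prefill)
)
print(args.url())  # e.g. http://127.0.0.1:<port>
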
sglang/srt/utils.py CHANGED
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
 
 
-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    return True
+    else:
+        return not is_embedding
 
 
 def decode_video_base64(video_base64):
@@ -347,7 +352,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -369,14 +374,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -384,6 +386,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
@@ -452,10 +456,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
 
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
 
 
-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":
@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
         return await call_next(request)
 
 
-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download
@@ -713,7 +713,7 @@ def prepare_model(model_path):
     return model_path
 
 
-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download
@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
     return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
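
A small sketch of how the reworked utilities above behave. configure_logger only reads log_level from the args object, so a stand-in namespace is used here, and the " TP0" prefix is purely illustrative.

from types import SimpleNamespace

from sglang.srt.utils import configure_logger, is_generation_model

# configure_logger only needs `log_level`, so a stand-in object is enough here.
configure_logger(SimpleNamespace(log_level="info"), prefix=" TP0")
# -> log lines are now formatted as "[HH:MM:SS TP0] message"

# The new is_embedding override: a generative architecture can now be served
# as an embedding model (see --is-embedding in server_args.py above).
assert is_generation_model(["LlamaForCausalLM"]) is True
assert is_generation_model(["LlamaForCausalLM"], is_embedding=True) is False
assert is_generation_model(["LlamaEmbeddingModel"]) is False
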
sglang/test/runners.py CHANGED
@@ -14,7 +14,7 @@ limitations under the License.
 """
 
 import json
-import multiprocessing
+import multiprocessing as mp
 import os
 from dataclasses import dataclass
 from typing import List, Union
@@ -24,15 +24,15 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
-from sglang.srt.utils import is_generation_model
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
-    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]
 
 dirpath = os.path.dirname(__file__)
@@ -63,44 +63,37 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        torch_dtype,
+        is_generation,
     ):
-        self.in_queue = multiprocessing.Queue()
-        self.out_queue = multiprocessing.Queue()
+        self.is_generation = is_generation
+
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
 
-        self.model_proc = multiprocessing.Process(
+        self.model_proc = mp.Process(
             target=self.start_model_process,
             args=(
                 self.in_queue,
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_generation_model,
             ),
         )
         self.model_proc.start()
 
-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
-            trust_remote_code=True,
         )
 
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
-        if self.is_generation_model:
+        if self.is_generation:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
+                trust_remote_code=False,
                 low_cpu_mem_usage=True,
-                trust_remote_code=True,
             ).cuda()
         else:
             from sentence_transformers import SentenceTransformer
@@ -113,7 +106,7 @@ class HFRunner:
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if self.is_generation_model:
+                if self.is_generation:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -176,22 +169,20 @@ class SRTRunner:
     def __init__(
         self,
         model_path,
+        torch_dtype,
+        is_generation,
         tp_size=1,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
-        port=5157,
+        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
     ):
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation = is_generation
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
             mem_fraction_static=0.7,
+            trust_remote_code=False,
+            is_embedding=not self.is_generation,
        )
 
     def forward(
@@ -199,7 +190,7 @@ class SRTRunner:
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
     ):
-        if self.is_generation_model:
+        if self.is_generation:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
@@ -209,6 +200,7 @@ class SRTRunner:
                     prompt,
                     sampling_params=sampling_params,
                     return_logprob=True,
+                    logprob_start_len=0,
                     top_logprobs_num=NUM_TOP_LOGPROBS,
                 )
                 response = json.loads(response)
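
The test runners now take torch_dtype and is_generation explicitly; a hedged sketch of the new calling convention follows. The model path is a placeholder, and HFRunner.forward is assumed to mirror the SRTRunner.forward signature shown above.

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder

# torch_dtype and is_generation no longer have defaults.
srt_runner = SRTRunner(MODEL, torch_dtype=torch.float16, is_generation=True)
srt_outputs = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

hf_runner = HFRunner(MODEL, torch_dtype=torch.float16, is_generation=True)
hf_outputs = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)  # assumed symmetric API
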