sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/srt/openai_api/adapter.py CHANGED
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  if end_point == "/v1/chat/completions":
  responses = v1_chat_generate_response(request, ret, to_file=True)
  else:
- responses = v1_generate_response(request, ret, to_file=True)
+ responses = v1_generate_response(
+ request, ret, tokenizer_manager, to_file=True
+ )

  except Exception as e:
  error_json = {
@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
  return_logprobs = []
  top_logprobs_nums = []
  first_prompt_type = type(all_requests[0].prompt)
+
  for request in all_requests:
  prompt = request.prompt
  assert (
@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
  )
  if len(all_requests) > 1 and request.n > 1:
  raise ValueError(
- "Batch operation is not supported for completions from files"
+ "Parallel sampling is not supported for completions from files"
  )

  if len(all_requests) == 1:
@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
  else:
  prompt_kwargs = {"input_ids": prompt}
  else:
- if isinstance(prompts[0], str):
+ if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
  prompt_kwargs = {"text": prompts}
  else:
  prompt_kwargs = {"input_ids": prompts}
+
  adapted_request = GenerateReqInput(
  **prompt_kwargs,
  sampling_params=sampling_params_list,
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
  return_text_in_logprobs=True,
  stream=all_requests[0].stream,
  )
+
  if len(all_requests) == 1:
  return adapted_request, all_requests[0]
  return adapted_request, all_requests


- def v1_generate_response(request, ret, to_file=False):
+ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
  choices = []
  echo = False

- if (not isinstance(request, List)) and request.echo:
+ if (not isinstance(request, list)) and request.echo:
  # TODO: handle the case propmt is token ids
- if isinstance(request.prompt, list):
+ if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+ # for the case of multiple str prompts
  prompts = request.prompt
+ elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+ # for the case of multiple token ids prompts
+ prompts = [
+ tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+ for prompt in request.prompt
+ ]
+ elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+ # for the case of single token ids prompt
+ prompts = [
+ tokenizer_manager.tokenizer.decode(
+ request.prompt, skip_special_tokens=True
+ )
+ ]
  else:
+ # for the case of single str prompt
  prompts = [request.prompt]
  echo = True

  for idx, ret_item in enumerate(ret):
  text = ret_item["text"]
- if isinstance(request, List) and request[idx].echo:
+ if isinstance(request, list) and request[idx].echo:
  echo = True
  text = request[idx].prompt + text
- if (not isinstance(request, List)) and echo:
- text = prompts[idx] + text
+ if (not isinstance(request, list)) and echo:
+ prompt_index = idx // request.n
+ text = prompts[prompt_index] + text

  logprobs = False
- if isinstance(request, List) and request[idx].logprobs:
+ if isinstance(request, list) and request[idx].logprobs:
  logprobs = True
- elif (not isinstance(request, List)) and request.logprobs:
+ elif (not isinstance(request, list)) and request.logprobs:
  logprobs = True
  if logprobs:
  if echo:
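Note on the echo change above: with parallel sampling (request.n > 1) the returned completions are grouped per prompt, so the completion index has to be divided by n to recover the prompt to prepend. A minimal standalone sketch of that mapping, with made-up values (not sglang code):

    # Each prompt yields n completions in order, so completion idx maps to prompt idx // n.
    prompts = ["What is 1+1?", "Name a color."]
    n = 3  # samples per prompt
    completions = [f"<gen {i}>" for i in range(len(prompts) * n)]  # 6 generations

    for idx, text in enumerate(completions):
        prompt_index = idx // n          # 0, 0, 0, 1, 1, 1
        echoed = prompts[prompt_index] + text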
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
  responses.append(response)
  return responses
  else:
+ prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
  completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
  response = CompletionResponse(
  id=ret[0]["meta_info"]["id"],
  model=request.model,
  choices=choices,
  usage=UsageInfo(
- prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+ prompt_tokens=prompt_tokens,
  completion_tokens=completion_tokens,
- total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
  ),
  )
  return response
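The usage fix above sums prompt_tokens over every returned item instead of reading only the first one, which matters once a batch (or n > 1) produces several entries in ret. A small illustration with made-up meta_info values:

    ret = [
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
    ]
    prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)          # 10
    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)  # 16
    total_tokens = prompt_tokens + completion_tokens                                 # 26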
@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

  if not stream_buffer: # The first chunk
  if request.echo:
+ if isinstance(request.prompt, str):
+ # for the case of single str prompts
+ prompts = request.prompt
+ elif isinstance(request.prompt, list) and isinstance(
+ request.prompt[0], int
+ ):
+ prompts = tokenizer_manager.tokenizer.decode(
+ request.prompt, skip_special_tokens=True
+ )
+
  # Prepend prompt in response text.
- text = request.prompt + text
+ text = prompts + text

  if request.logprobs:
  # The first chunk and echo is enabled.
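The streaming echo path above now accepts a prompt given either as a string or as a list of token ids, decoding the latter before prepending. A hedged sketch of that dispatch as a generic helper (not the sglang implementation; decode stands in for tokenizer_manager.tokenizer.decode):

    def resolve_echo_prompt(prompt, decode):
        """Return the text to prepend when echo is enabled."""
        if isinstance(prompt, str):
            return prompt
        if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
            return decode(prompt, skip_special_tokens=True)
        raise ValueError("unsupported prompt type for echo")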
@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  "output_top_logprobs"
  ][n_prev_token:],
  )
-
  n_prev_token = len(
  content["meta_info"]["output_token_logprobs"]
  )
@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  if not isinstance(ret, list):
  ret = [ret]

- response = v1_generate_response(request, ret)
+ response = v1_generate_response(request, ret, tokenizer_manager)
  return response


@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
  else:
  # Use the raw prompt and stop strings if the messages is already a string.
- prompt = request.messages
+ prompt_ids = request.messages
  stop = request.stop
  image_data = None
  input_ids.append(prompt_ids)
@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  image_data_list.append(image_data)
  if len(all_requests) == 1:
  input_ids = input_ids[0]
+ if isinstance(input_ids, str):
+ prompt_kwargs = {"text": input_ids}
+ else:
+ prompt_kwargs = {"input_ids": input_ids}
  sampling_params_list = sampling_params_list[0]
  image_data = image_data_list[0]
  return_logprobs = return_logprobs[0]
  top_logprobs_nums = top_logprobs_nums[0]
+ else:
+ if isinstance(input_ids[0], str):
+ prompt_kwargs = {"text": input_ids}
+ else:
+ prompt_kwargs = {"input_ids": input_ids}
  adapted_request = GenerateReqInput(
- input_ids=input_ids,
+ **prompt_kwargs,
  image_data=image_data,
  sampling_params=sampling_params_list,
  return_logprob=return_logprobs,
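When messages is already a raw string, the adapter now forwards it to GenerateReqInput as text rather than input_ids, in both the single-request and batched paths. A rough sketch of the dispatch as a hypothetical helper (not sglang code):

    def prompt_kwargs_for(input_ids):
        # Single request: a raw string prompt or a list of token ids.
        # Batched: a list of such items, so inspect the first element instead.
        probe = input_ids if isinstance(input_ids, str) else input_ids[0]
        key = "text" if isinstance(probe, str) else "input_ids"
        return {key: input_ids}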
@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):

  for idx, ret_item in enumerate(ret):
  logprobs = False
- if isinstance(request, List) and request[idx].logprobs:
+ if isinstance(request, list) and request[idx].logprobs:
  logprobs = True
- elif (not isinstance(request, List)) and request.logprobs:
+ elif (not isinstance(request, list)) and request.logprobs:
  logprobs = True
  if logprobs:
  logprobs = to_openai_style_logprobs(
@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  is_first = True

  stream_buffer = ""
+ n_prev_token = 0
  try:
  async for content in tokenizer_manager.generate_request(
  adapted_request, raw_request
  ):
+ prompt_tokens = content["meta_info"]["prompt_tokens"]
+ completion_tokens = content["meta_info"]["completion_tokens"]
+ if request.logprobs:
+ logprobs = to_openai_style_logprobs(
+ output_token_logprobs=content["meta_info"][
+ "output_token_logprobs"
+ ][n_prev_token:],
+ output_top_logprobs=content["meta_info"][
+ "output_top_logprobs"
+ ][n_prev_token:],
+ )
+
+ n_prev_token = len(
+ content["meta_info"]["output_token_logprobs"]
+ )
+ token_logprobs = []
+ for token, logprob in zip(
+ logprobs.tokens, logprobs.token_logprobs
+ ):
+ token_bytes = list(token.encode("utf-8"))
+ top_logprobs = []
+ if logprobs.top_logprobs:
+ for top_token, top_logprob in logprobs.top_logprobs[
+ 0
+ ].items():
+ top_token_bytes = list(top_token.encode("utf-8"))
+ top_logprobs.append(
+ TopLogprob(
+ token=top_token,
+ bytes=top_token_bytes,
+ logprob=top_logprob,
+ )
+ )
+ token_logprobs.append(
+ ChatCompletionTokenLogprob(
+ token=token,
+ bytes=token_bytes,
+ logprob=logprob,
+ top_logprobs=top_logprobs,
+ )
+ )
+
+ choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+ else:
+ choice_logprobs = None
+
  if is_first:
  # First chunk with role
  is_first = False
@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  index=0,
  delta=DeltaMessage(role="assistant"),
  finish_reason=content["meta_info"]["finish_reason"],
+ logprobs=choice_logprobs,
  )
  chunk = ChatCompletionStreamResponse(
  id=content["meta_info"]["id"],
  choices=[choice_data],
  model=request.model,
+ usage=UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
  )
  yield f"data: {chunk.model_dump_json()}\n\n"

@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  index=0,
  delta=DeltaMessage(content=delta),
  finish_reason=content["meta_info"]["finish_reason"],
+ logprobs=choice_logprobs,
  )
  chunk = ChatCompletionStreamResponse(
  id=content["meta_info"]["id"],
  choices=[choice_data],
  model=request.model,
+ usage=UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
  )
  yield f"data: {chunk.model_dump_json()}\n\n"
  except ValueError as e:
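With the two streaming hunks above, every chat chunk can now carry per-token logprobs (when requested) and a UsageInfo block. A hedged usage sketch against the OpenAI-compatible endpoint, assuming the openai>=1.x Python client and a server on localhost:30000 (the model name is only echoed back, so "default" is a placeholder):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    stream = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Say hi"}],
        logprobs=True,
        top_logprobs=2,
        stream=True,
    )
    for chunk in stream:
        choice = chunk.choices[0]
        if choice.logprobs:
            print(choice.logprobs)  # per-token logprob entries for this delta
        print(chunk.usage)          # prompt/completion/total token counts per chunk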
sglang/srt/openai_api/protocol.py CHANGED
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
  class ChatCompletionResponseStreamChoice(BaseModel):
  index: int
  delta: DeltaMessage
- logprobs: Optional[LogProbs] = None
+ logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
  finish_reason: Optional[str] = None

sglang/srt/server.py CHANGED
@@ -28,7 +28,7 @@ import sys
  import threading
  import time
  from http import HTTPStatus
- from typing import Dict, Optional
+ from typing import Dict, List, Optional, Union

  # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
  from sglang.srt.openai_api.protocol import ModelCard, ModelList
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
- API_KEY_HEADER_NAME,
- APIKeyValidatorMiddleware,
+ add_api_key_middleware,
  allocate_init_ports,
  assert_pkg_version,
  enable_show_time_cost,
  kill_child_process,
  maybe_set_triton_cache_manager,
+ set_torch_compile_config,
  set_ulimit,
  )
  from sglang.utils import get_exception_traceback
@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
  return await v1_chat_completions(tokenizer_manager, raw_request)


+ @app.get("/v1/models")
+ def available_models():
+ """Show available models."""
+ served_model_names = [tokenizer_manager.served_model_name]
+ model_cards = []
+ for served_model_name in served_model_names:
+ model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+ return ModelList(data=model_cards)
+
+
  @app.post("/v1/files")
  async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
  return await v1_files_create(
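The /v1/models route itself is unchanged; this hunk only moves it ahead of the file-handling routes (it is removed from its old position in the next hunk). For reference, a hedged example of querying it, assuming a server on localhost:30000:

    import requests

    resp = requests.get(
        "http://localhost:30000/v1/models",
        headers={"Authorization": "Bearer my-key"},  # only needed when --api-key is set
    )
    print(resp.json())  # roughly {"object": "list", "data": [{"id": "<served model name>", ...}]}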
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
  return await v1_retrieve_file_content(file_id)


- @app.get("/v1/models")
- def available_models():
- """Show available models."""
- served_model_names = [tokenizer_manager.served_model_name]
- model_cards = []
- for served_model_name in served_model_names:
- model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
- return ModelList(data=model_cards)
-
-
- def _set_torch_compile_config():
- # The following configurations are for torch compile optimizations
- import torch._dynamo.config
- import torch._inductor.config
-
- torch._inductor.config.coordinate_descent_tuning = True
- torch._inductor.config.triton.unique_kernel_names = True
- torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-
- # FIXME: tmp workaround
- torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
- def set_envs_and_config(server_args: ServerArgs):
- # Set global environments
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- os.environ["NCCL_CUMEM_ENABLE"] = "0"
- os.environ["NCCL_NVLS_ENABLE"] = "0"
- os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
- # Set ulimit
- set_ulimit()
-
- # Enable show time cost for debugging
- if server_args.show_time_cost:
- enable_show_time_cost()
-
- # Disable disk cache
- if server_args.disable_disk_cache:
- disable_cache()
-
- # Fix triton bugs
- if server_args.tp_size * server_args.dp_size > 1:
- # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
- maybe_set_triton_cache_manager()
-
- # Set torch compile config
- if server_args.enable_torch_compile:
- _set_torch_compile_config()
-
- # Set global chat template
- if server_args.chat_template:
- # TODO: replace this with huggingface transformers template
- load_chat_template_for_openai_api(server_args.chat_template)
-
-
  def launch_server(
  server_args: ServerArgs,
  model_overide_args: Optional[dict] = None,
  pipe_finish_writer: Optional[mp.connection.Connection] = None,
  ):
- server_args.check_server_args()
-
  """Launch an HTTP server."""
  global tokenizer_manager

@@ -258,16 +210,8 @@ def launch_server(
  format="%(message)s",
  )

- if not server_args.disable_flashinfer:
- assert_pkg_version(
- "flashinfer",
- "0.1.3",
- "Please uninstall the old version and "
- "reinstall the latest version by following the instructions "
- "at https://docs.flashinfer.ai/installation.html.",
- )
-
- set_envs_and_config(server_args)
+ server_args.check_server_args()
+ _set_envs_and_config(server_args)

  # Allocate ports
  server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +228,7 @@ def launch_server(
  )
  logger.info(f"{server_args=}")

- # Handle multi-node tensor parallelism
+ # Launch processes for multi-node tensor parallelism
  if server_args.nnodes > 1:
  if server_args.node_rank != 0:
  tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +293,9 @@ def launch_server(
  sys.exit(1)
  assert proc_controller.is_alive() and proc_detoken.is_alive()

- if server_args.api_key and server_args.api_key != "":
- app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+ # Add api key authorization
+ if server_args.api_key:
+ add_api_key_middleware(app, server_args.api_key)

  # Send a warmup request
  t = threading.Thread(
@@ -372,15 +317,58 @@ def launch_server(
  t.join()


+ def _set_envs_and_config(server_args: ServerArgs):
+ # Set global environments
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ os.environ["NCCL_NVLS_ENABLE"] = "0"
+ os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+ # Set ulimit
+ set_ulimit()
+
+ # Enable show time cost for debugging
+ if server_args.show_time_cost:
+ enable_show_time_cost()
+
+ # Disable disk cache
+ if server_args.disable_disk_cache:
+ disable_cache()
+
+ # Fix triton bugs
+ if server_args.tp_size * server_args.dp_size > 1:
+ # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+ maybe_set_triton_cache_manager()
+
+ # Set torch compile config
+ if server_args.enable_torch_compile:
+ set_torch_compile_config()
+
+ # Set global chat template
+ if server_args.chat_template:
+ # TODO: replace this with huggingface transformers template
+ load_chat_template_for_openai_api(server_args.chat_template)
+
+ # Check flashinfer version
+ if not server_args.disable_flashinfer:
+ assert_pkg_version(
+ "flashinfer",
+ "0.1.3",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
+
+
  def _wait_and_warmup(server_args, pipe_finish_writer):
  headers = {}
  url = server_args.url()
  if server_args.api_key:
- headers[API_KEY_HEADER_NAME] = server_args.api_key
+ headers["Authorization"] = f"Bearer {server_args.api_key}"

  # Wait until the server is launched
  for _ in range(120):
- time.sleep(0.5)
+ time.sleep(1)
  try:
  requests.get(url + "/get_model_info", timeout=5, headers=headers)
  break
@@ -481,10 +469,10 @@ class Runtime:
  trust_remote_code=self.server_args.trust_remote_code,
  )

- async def add_request(
+ async def async_generate(
  self,
  prompt: str,
- sampling_params: Dict,
+ sampling_params: Optional[Dict] = None,
  ):
  json_data = {
  "text": prompt,
@@ -507,5 +495,26 @@ class Runtime:
  yield cur
  pos += len(cur)

+ add_request = async_generate
+
+ def generate(
+ self,
+ prompt: str,
+ sampling_params: Optional[Dict] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ ):
+ json_data = {
+ "text": prompt,
+ "sampling_params": sampling_params,
+ "return_logprob": return_logprob,
+ "top_logprobs_num": top_logprobs_num,
+ }
+ response = requests.post(
+ self.url + "/generate",
+ json=json_data,
+ )
+ return json.dumps(response.json())
+
  def __del__(self):
  self.shutdown()
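add_request is kept as an alias of the renamed async_generate, and the new synchronous generate posts to the /generate endpoint and returns the response as a JSON string. A hedged usage sketch (assumes sglang is installed and the model path shown here, which is only a placeholder, is reachable):

    import json

    import sglang as sgl

    runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
    out = runtime.generate(
        "The capital of France is",
        sampling_params={"max_new_tokens": 16, "temperature": 0},
    )
    print(json.loads(out)["text"])  # generate() returns a JSON string, so decode it back
    runtime.shutdown()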
sglang/srt/server_args.py CHANGED
@@ -61,7 +61,7 @@ class ServerArgs:
  show_time_cost: bool = False

  # Other
- api_key: str = ""
+ api_key: Optional[str] = None
  file_storage_pth: str = "SGlang_storage"

  # Data parallelism
@@ -80,6 +80,7 @@ class ServerArgs:
  disable_disk_cache: bool = False
  enable_torch_compile: bool = False
  enable_p2p_check: bool = False
+ enable_mla: bool = False
  attention_reduce_in_fp32: bool = False
  efficient_weight_load: bool = False

@@ -306,7 +307,7 @@ class ServerArgs:
  "--api-key",
  type=str,
  default=ServerArgs.api_key,
- help="Set API key of the server.",
+ help="Set API key of the server. It is also used in the OpenAI API compatible server.",
  )
  parser.add_argument(
  "--file-storage-pth",
@@ -393,6 +394,11 @@ class ServerArgs:
  action="store_true",
  help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
  )
+ parser.add_argument(
+ "--enable-mla",
+ action="store_true",
+ help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+ )
  parser.add_argument(
  "--attention-reduce-in-fp32",
  action="store_true",
sglang/srt/utils.py CHANGED
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
  raise RuntimeError("Could not create or locate cache dir")


- API_KEY_HEADER_NAME = "X-API-Key"
-
-
- class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
- def __init__(self, app, api_key: str):
- super().__init__(app)
- self.api_key = api_key
-
- async def dispatch(self, request, call_next):
- # extract API key from the request headers
- api_key_header = request.headers.get(API_KEY_HEADER_NAME)
- if not api_key_header or api_key_header != self.api_key:
- return JSONResponse(
- status_code=403,
- content={"detail": "Invalid API Key"},
- )
- response = await call_next(request)
- return response
-
-
  def get_ip_address(ifname):
  """
  Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
  dist.destroy_process_group()


+ def set_torch_compile_config():
+ # The following configurations are for torch compile optimizations
+ import torch._dynamo.config
+ import torch._inductor.config
+
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.triton.unique_kernel_names = True
+ torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+
+ # FIXME: tmp workaround
+ torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
  def set_ulimit(target_soft_limit=65535):
  resource_type = resource.RLIMIT_NOFILE
  current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
  origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

  setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+ def add_api_key_middleware(app, api_key):
+ @app.middleware("http")
+ async def authentication(request, call_next):
+ if request.method == "OPTIONS":
+ return await call_next(request)
+ if request.url.path.startswith("/health"):
+ return await call_next(request)
+ if request.headers.get("Authorization") != "Bearer " + api_key:
+ return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+ return await call_next(request)
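
The new middleware replaces the old X-API-Key header check with standard Bearer authorization: OPTIONS requests and /health pass through, everything else must send the exact key. A hedged sketch of exercising it against a running server (assumes it was launched with --api-key my-key on port 30000):

    import requests

    base = "http://localhost:30000"
    print(requests.get(base + "/health").status_code)          # 200, no auth required
    print(requests.get(base + "/get_model_info").status_code)  # 401, missing header
    print(requests.get(
        base + "/get_model_info",
        headers={"Authorization": "Bearer my-key"},
    ).status_code)                                              # 200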