sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0

sglang/srt/disaggregation/utils.py

@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -84,28 +85,48 @@ class ReqToMetadataIdxAllocator:
 
 
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
@@ -113,6 +134,7 @@ class MetadataBuffers:
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
@@ -120,6 +142,7 @@ class MetadataBuffers:
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
@@ -130,6 +153,7 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
@@ -139,6 +163,10 @@ class MetadataBuffers:
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
         if req.return_logprob:
            if req.output_token_logprobs_val:  # not none or empty list
                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
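
Note: for orientation, here is a hedged sketch of how the reworked MetadataBuffers constructor might be instantiated after this change. The numeric values and the behavior of get_buf_infos() are illustrative assumptions, not taken from this diff.

    # Illustrative sketch only (values are examples, not from this diff).
    import torch
    from sglang.srt.disaggregation.utils import MetadataBuffers

    buffers = MetadataBuffers(
        size=256,              # number of request slots (example value)
        hidden_size=4096,      # model hidden size (example value)
        dtype=torch.bfloat16,  # dtype used for the transferred hidden states
        custom_mem_pool=None,  # None keeps the previous CPU placement
    )
    # With a custom torch.cuda.MemPool, the same buffers are instead allocated
    # on CUDA inside torch.cuda.use_mem_pool(...), as the new __init__ shows.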

sglang/srt/distributed/parallel_state.py

@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
 
-        if tensor_list is not None:
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
 
         assert (
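
Note: the renamed output_tensor_list parameter follows torch.distributed.all_gather's in-place convention: when a pre-allocated list is passed, results are written into it and the call returns None. A standalone sketch of that convention using plain torch.distributed (not the sglang GroupCoordinator) with a single-process gloo group:

    # Standalone illustration of the in-place all-gather convention (gloo, world_size=1).
    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    input_ = torch.arange(4, dtype=torch.float32)
    output_tensor_list = [torch.empty_like(input_) for _ in range(dist.get_world_size())]

    # The gathered shards are written into the pre-allocated list; the call returns None.
    dist.all_gather(output_tensor_list, input_)
    assert torch.equal(output_tensor_list[0], input_)

    dist.destroy_process_group()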

sglang/srt/entrypoints/engine.py

@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop
 
-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
         logger.info(f"{server_args=}")
 
         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info
 
         context = zmq.Context(2)
@@ -175,7 +172,7 @@ class Engine(EngineBase):
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -258,7 +255,7 @@ class Engine(EngineBase):
 
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -479,17 +476,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def release_memory_occupation(self):
-        """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )
 
-    def resume_memory_occupation(self):
-        """Resume GPU occupation."""
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
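
Note: the new tags parameter makes it possible to release or resume only part of the GPU state rather than everything at once. A hedged usage sketch; the tag strings "weights" and "kv_cache" are an assumption (the new sglang/srt/constants.py added in this release appears to define such constants), and the model path is a placeholder:

    # Hedged sketch: tag values and model path are assumptions, not taken from this diff.
    import sglang as sgl

    engine = sgl.Engine(model_path="path/to/model", enable_memory_saver=True)

    # Release only the KV cache while keeping the weights resident (assumed tag name).
    engine.release_memory_occupation(tags=["kv_cache"])
    # ... run other GPU work here ...
    engine.resume_memory_occupation(tags=["kv_cache"])

    # tags=None keeps the old behavior and releases/resumes everything.
    engine.release_memory_occupation()
    engine.resume_memory_occupation()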

@@ -649,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
 
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -670,11 +665,9 @@ def _launch_subprocesses(
 
     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []
 
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +703,7 @@ def _launch_subprocesses(
                         writer,
                     ),
                 )
+
                 with memory_saver_adapter.configure_subprocess():
                     proc.start()
                 scheduler_procs.append(proc)
@@ -735,7 +729,7 @@ def _launch_subprocesses(
 
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            # When using `Engine` as a Python API, we don't want to block here.
-            return None, None
+            return None, None, None
 
        launch_dummy_health_check_server(server_args.host, server_args.port)
 
@@ -744,7 +738,7 @@ def _launch_subprocesses(
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
-            return None, None
+            return None, None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -758,15 +752,15 @@ def _launch_subprocesses(
 
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)
 
-    if server_args.completion_template:
-        load_completion_template_for_openai_api(server_args.completion_template)
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
 
     # Wait for the model to finish loading
     scheduler_infos = []
@@ -790,4 +784,4 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info

sglang/srt/entrypoints/http_server.py

@@ -38,7 +38,8 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
@@ -47,6 +48,21 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ErrorResponse,
+    ModelCard,
+    ModelList,
+    ScoringRequest,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -67,26 +83,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
-    V1RerankReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
-    v1_batches,
-    v1_cancel_batch,
-    v1_chat_completions,
-    v1_completions,
-    v1_delete_file,
-    v1_embeddings,
-    v1_files_create,
-    v1_rerank,
-    v1_retrieve_batch,
-    v1_retrieve_file,
-    v1_retrieve_file_content,
-    v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -109,6 +110,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class _GlobalState:
     tokenizer_manager: TokenizerManager
+    template_manager: TemplateManager
     scheduler_info: Dict
 
 
@@ -123,6 +125,24 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
     server_args: ServerArgs = fast_api_app.server_args
+
+    # Initialize OpenAI serving handlers
+    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_score = OpenAIServingScore(
+        _global_state.tokenizer_manager
+    )
+    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+        _global_state.tokenizer_manager
+    )
+
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -148,6 +168,47 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Override FastAPI's default 422 validation error with 400"""
+    exc_str = str(exc)
+    errors_str = str(exc.errors())
+
+    if errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        message=message,
+        type=HTTPStatus.BAD_REQUEST.phrase,
+        code=HTTPStatus.BAD_REQUEST.value,
+    )
+
+    return ORJSONResponse(
+        status_code=400,
+        content=err.model_dump(),
+    )
+
+
+async def validate_json_request(raw_request: Request):
+    """Validate that the request content-type is application/json."""
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=[
+                {
+                    "loc": ["header", "content-type"],
+                    "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+                    "type": "value_error",
+                }
+            ]
+        )
+
+
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 
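
Note: two behavioral consequences for clients follow from these additions: request bodies that fail Pydantic validation now return 400 (instead of FastAPI's default 422), and endpoints guarded by validate_json_request reject non-JSON content types with the same 400-style error. A client-side sketch; the host and port are placeholders for a running server:

    # Client-side sketch (placeholder URL; assumes a server is listening there).
    import requests

    base = "http://localhost:30000"

    # Non-JSON content type is rejected by validate_json_request.
    r = requests.post(
        f"{base}/v1/chat/completions",
        data="plain text body",
        headers={"Content-Type": "text/plain"},
    )
    print(r.status_code)  # expected: 400

    # A body missing the required "messages" field now also maps to 400, not 422.
    r = requests.post(f"{base}/v1/chat/completions", json={"model": "default"})
    print(r.status_code)  # expected: 400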
 

@@ -330,13 +391,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
-@app.api_route("/v1/rerank", methods=["POST", "PUT"])
-async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
-    try:
-        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
 
 
 @app.api_route("/flush_cache", methods=["GET", "POST"])

@@ -619,25 +681,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
 ##### OpenAI-compatible API endpoints #####
 
 
-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )
 
 
-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )
 
 
-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-async def openai_v1_embeddings(raw_request: Request):
-    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
-    return response
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )
 
 
 @app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
-    """Show available models."""
+async def available_models():
+    """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
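
Note: since the refactored handlers keep the OpenAI wire format, existing clients should continue to work unchanged. A minimal sketch with the official openai client (base URL and model name are placeholders):

    # Minimal sketch; base URL and model name are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    completion = client.chat.completions.create(
        model="default",  # use a name returned by GET /v1/models
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(completion.choices[0].message.content)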

@@ -651,45 +727,29 @@ def available_models():
     return ModelList(data=model_cards)
 
 
-@app.post("/v1/files")
-async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
-    return await v1_files_create(
-        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
-    )
-
-
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/delete
-    return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
-    # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
-    return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve
-    return await v1_retrieve_file(file_id)
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+    """Retrieves a model instance, providing basic information about the model."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]
 
+    if model not in served_model_names:
+        return ORJSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )
 
-@app.get("/v1/files/{file_id}/content")
-async def retrieve_file_content(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
-    return await v1_retrieve_file_content(file_id)
+    return ModelCard(
+        id=model,
+        root=model,
+        max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+    )
 
 
 ## SageMaker API
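
Note: the file and batch endpoints above were removed along with the old openai_api adapter, while a single-model retrieval endpoint was added. A hedged sketch of querying it (placeholder URL; the 404 payload shape follows the code above):

    # Placeholder URL; assumes a running server.
    import requests

    base = "http://localhost:30000"
    model_id = requests.get(f"{base}/v1/models").json()["data"][0]["id"]

    print(requests.get(f"{base}/v1/models/{model_id}").status_code)          # 200, ModelCard payload
    print(requests.get(f"{base}/v1/models/not-a-served-model").status_code)  # 404, model_not_found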

@@ -700,8 +760,13 @@ async def sagemaker_health() -> Response:
 
 
 @app.post("/invocations")
-async def sagemaker_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )
 
 
 ## Vertex AI API
@@ -732,10 +797,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
     return ORJSONResponse({"predictions": ret})
 
 
-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await v1_score(_global_state.tokenizer_manager, raw_request)
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
 
 
 def _create_error_response(e):
@@ -764,10 +831,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )

sglang/srt/entrypoints/http_server_engine.py

@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):
 
     def _make_request(self, endpoint: str, payload: Optional[dict] = None):
         """Make a POST request to the specified endpoint with the given payload.
-
         Args:
             endpoint: The API endpoint to call
             payload: The JSON payload to send (default: empty dict)
-
         Returns:
             The JSON response from the server
         """
@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
     ):
         """
         Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
-
         Note: The model should be on GPUs rather than CPU for this functionality to work properly.
         If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
         """