sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff reflects the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/http_server.py

@@ -38,15 +38,31 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
 from sglang.srt.disaggregation.utils import (
-    FakeBootstrapHost,
+    FAKE_BOOTSTRAP_HOST,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ErrorResponse,
+    ModelCard,
+    ModelList,
+    ScoringRequest,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -69,22 +85,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
-    v1_batches,
-    v1_cancel_batch,
-    v1_chat_completions,
-    v1_completions,
-    v1_delete_file,
-    v1_embeddings,
-    v1_files_create,
-    v1_retrieve_batch,
-    v1_retrieve_file,
-    v1_retrieve_file_content,
-    v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -107,6 +110,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class _GlobalState:
     tokenizer_manager: TokenizerManager
+    template_manager: TemplateManager
     scheduler_info: Dict
 
 
@@ -121,6 +125,24 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
     server_args: ServerArgs = fast_api_app.server_args
+
+    # Initialize OpenAI serving handlers
+    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_score = OpenAIServingScore(
+        _global_state.tokenizer_manager
+    )
+    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+        _global_state.tokenizer_manager
+    )
+
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -146,6 +168,47 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Override FastAPI's default 422 validation error with 400"""
+    exc_str = str(exc)
+    errors_str = str(exc.errors())
+
+    if errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        message=message,
+        type=HTTPStatus.BAD_REQUEST.phrase,
+        code=HTTPStatus.BAD_REQUEST.value,
+    )
+
+    return ORJSONResponse(
+        status_code=400,
+        content=err.model_dump(),
+    )
+
+
+async def validate_json_request(raw_request: Request):
+    """Validate that the request content-type is application/json."""
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=[
+                {
+                    "loc": ["header", "content-type"],
+                    "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+                    "type": "value_error",
+                }
+            ]
+        )
+
+
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
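
Taken together, `validate_json_request` and the `RequestValidationError` handler above change how malformed requests surface to clients: non-JSON bodies and schema violations now come back as HTTP 400 with an `ErrorResponse` payload instead of FastAPI's default 422. A minimal client-side sketch of that behavior, assuming an sglang server already listening on `localhost:30000` (the port and payload values are illustrative assumptions, not part of the diff):

```python
# Hypothetical check of the new validation behavior; host/port are assumptions.
import requests

base = "http://localhost:30000"

# A non-JSON content type is rejected by validate_json_request and surfaces
# through the RequestValidationError handler as HTTP 400 (not 422).
r = requests.post(
    f"{base}/v1/completions",
    data="prompt=hello",
    headers={"Content-Type": "text/plain"},
)
print(r.status_code)  # expected: 400
print(r.json())       # ErrorResponse-shaped body with message/type/code fields

# A well-formed JSON request passes the dependency and reaches the handler.
r = requests.post(
    f"{base}/v1/completions",
    json={"model": "default", "prompt": "hello", "max_tokens": 8},
)
print(r.status_code)
```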
@@ -328,6 +391,16 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -608,25 +681,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
 ##### OpenAI-compatible API endpoints #####
 
 
-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )
 
 
-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )
 
 
-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-async def openai_v1_embeddings(raw_request: Request):
-    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
-    return response
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )
 
 
 @app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
-    """Show available models."""
+async def available_models():
+    """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
@@ -640,45 +727,29 @@ def available_models():
     return ModelList(data=model_cards)
 
 
-@app.post("/v1/files")
-async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
-    return await v1_files_create(
-        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
-    )
-
-
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/delete
-    return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
-    # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
-    return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve
-    return await v1_retrieve_file(file_id)
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+    """Retrieves a model instance, providing basic information about the model."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]
 
+    if model not in served_model_names:
+        return ORJSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )
 
-@app.get("/v1/files/{file_id}/content")
-async def retrieve_file_content(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
-    return await v1_retrieve_file_content(file_id)
+    return ModelCard(
+        id=model,
+        root=model,
+        max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+    )
 
 
 ## SageMaker API
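
The file and batch endpoints that were backed by the removed `openai_api/adapter.py` are gone; in their place, `/v1/models/{model:path}` serves a single `ModelCard` or an OpenAI-style 404 error. A quick check, with the host and the served model name (`default`) as assumptions for illustration:

```python
# Illustrative queries against the model endpoints; host and model name are assumptions.
import requests

base = "http://localhost:30000"

print(requests.get(f"{base}/v1/models").json())          # ModelList of served models
print(requests.get(f"{base}/v1/models/default").json())  # ModelCard incl. max_model_len

missing = requests.get(f"{base}/v1/models/does-not-exist")
print(missing.status_code)  # expected: 404, with error.code == "model_not_found"
```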
@@ -689,8 +760,13 @@ async def sagemaker_health() -> Response:
 
 
 @app.post("/invocations")
-async def sagemaker_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )
 
 
 ## Vertex AI API
@@ -721,10 +797,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
     return ORJSONResponse({"predictions": ret})
 
 
-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await v1_score(_global_state.tokenizer_manager, raw_request)
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
 
 
 def _create_error_response(e):
@@ -753,10 +831,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )
@@ -878,7 +959,7 @@ def _wait_and_warmup(
                 "max_new_tokens": 8,
                 "ignore_eos": True,
             },
-            "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
+            "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
             # This is a hack to ensure fake transfer is enabled during prefill warmup
            # ensure each dp rank has a unique bootstrap_room during prefill warmup
            "bootstrap_room": [
sglang/srt/entrypoints/http_server_engine.py

@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):
 
     def _make_request(self, endpoint: str, payload: Optional[dict] = None):
         """Make a POST request to the specified endpoint with the given payload.
-
         Args:
             endpoint: The API endpoint to call
             payload: The JSON payload to send (default: empty dict)
-
         Returns:
             The JSON response from the server
         """
@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
     ):
         """
         Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
-
        Note: The model should be on GPUs rather than CPU for this functionality to work properly.
        If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
        """
sglang/srt/entrypoints/openai/__init__.py
File without changes
sglang/srt/{openai_api → entrypoints/openai}/protocol.py

@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, root_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal
 
 
@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -188,13 +195,28 @@ class CompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    @field_validator("max_tokens")
+    @classmethod
+    def validate_max_tokens_positive(cls, v):
+        if v is not None and v <= 0:
+            raise ValueError("max_tokens must be positive")
+        return v
+
 
 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter", "abort"]
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data
 
 
 class CompletionResponse(BaseModel):
@@ -212,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data
 
 
 class CompletionStreamResponse(BaseModel):
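
The request and response models above now carry `return_hidden_states` / `hidden_states`, with a wrap-mode `@model_serializer` dropping the `hidden_states` key from the serialized payload whenever it is `None`, so existing clients never see it unless hidden states were requested; the `@field_validator` on `max_tokens` rejects non-positive values at parse time. A self-contained sketch of both Pydantic v2 patterns (the `Choice` model here is a hypothetical stand-in, not a class from the diff):

```python
# Standalone illustration of the two Pydantic v2 patterns used above;
# "Choice" is a hypothetical stand-in for the models in protocol.py.
from typing import Optional

from pydantic import BaseModel, field_validator, model_serializer


class Choice(BaseModel):
    text: str
    max_tokens: Optional[int] = None
    hidden_states: Optional[object] = None

    @field_validator("max_tokens")
    @classmethod
    def validate_max_tokens_positive(cls, v):
        if v is not None and v <= 0:
            raise ValueError("max_tokens must be positive")
        return v

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        data = handler(self)
        if self.hidden_states is None:
            data.pop("hidden_states", None)  # omit the key entirely when unset
        return data


print(Choice(text="hi").model_dump())
# {'text': 'hi', 'max_tokens': None}  -- no 'hidden_states' key
print(Choice(text="hi", hidden_states=[0.1, 0.2]).model_dump())
# {'text': 'hi', 'max_tokens': None, 'hidden_states': [0.1, 0.2]}
```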
@@ -369,8 +399,10 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     ) # noqa
+    return_hidden_states: bool = False
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def set_tool_choice_default(cls, values):
         if values.get("tool_choice") is None:
             if values.get("tools") is None:
@@ -417,10 +449,20 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason: Literal[
-        "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
-    ]
+    finish_reason: Optional[
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
+    ] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data
 
 
 class ChatCompletionResponse(BaseModel):
@@ -437,6 +479,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -463,15 +513,18 @@ class MultimodalEmbeddingInput(BaseModel):
     image: Optional[str] = None
 
 
+EmbeddingInput = Union[
+    List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+]
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[
-        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
-    ]
+    input: EmbeddingInput
     model: str
     encoding_format: str = "float"
-    dimensions: int = None
+    dimensions: Optional[int] = None
     user: Optional[str] = None
 
     # The request id.
@@ -513,3 +566,24 @@ class ScoringResponse(BaseModel):
     model: str
     usage: Optional[UsageInfo] = None
     object: str = "scoring"
+
+
+class V1RerankReqInput(BaseModel):
+    query: str
+    documents: List[str]
+
+
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
+
+OpenAIServingRequest = Union[
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ScoringRequest,
+    V1RerankReqInput,
+]
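
The `V1RerankReqInput` and `RerankResponse` models just above back the `/v1/rerank` route added in http_server.py earlier in this diff (a `query` string plus a list of `documents`, scored per document). A rough usage sketch; the server address, query, and documents are assumptions for illustration:

```python
# Illustrative call to the new rerank endpoint; URL and inputs are assumptions.
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "what is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "The Great Wall is in China.",
        ],
    },
)
# Each result is expected to follow the RerankResponse model defined above:
# {"score": ..., "document": ..., "index": ..., "meta_info": ...}
for item in resp.json():
    print(item["index"], item["score"])
```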
sglang/srt/entrypoints/openai/serving_base.py (new file)

@@ -0,0 +1,149 @@
+import json
+import logging
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse, StreamingResponse
+
+from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__name__)
+
+
+# Base class for specific endpoint handlers
+class OpenAIServingBase(ABC):
+    """Abstract base class for OpenAI endpoint handlers"""
+
+    def __init__(self, tokenizer_manager: TokenizerManager):
+        self.tokenizer_manager = tokenizer_manager
+
+    async def handle_request(
+        self, request: OpenAIServingRequest, raw_request: Request
+    ) -> Union[Any, StreamingResponse, ErrorResponse]:
+        """Handle the specific request type with common pattern"""
+        try:
+            # Validate request
+            error_msg = self._validate_request(request)
+            if error_msg:
+                return self.create_error_response(error_msg)
+
+            # Convert to internal format
+            adapted_request, processed_request = self._convert_to_internal_request(
+                request
+            )
+
+            # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
+            if hasattr(request, "stream") and request.stream:
+                return await self._handle_streaming_request(
+                    adapted_request, processed_request, raw_request
+                )
+            else:
+                return await self._handle_non_streaming_request(
+                    adapted_request, processed_request, raw_request
+                )
+
+        except Exception as e:
+            logger.exception(f"Error in request: {e}")
+            return self.create_error_response(
+                message=f"Internal server error: {str(e)}",
+                err_type="InternalServerError",
+                status_code=500,
+            )
+
+    @abstractmethod
+    def _request_id_prefix(self) -> str:
+        """Generate request ID based on request type"""
+        pass
+
+    def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
+        """Generate request ID based on request type"""
+        return None
+
+        # TODO(chang): the rid is used in io_strcut check and often violates `The rid should be a list` AssertionError
+        # Temporarily return None in this function until the rid logic is clear.
+        if rid := getattr(request, "rid", None):
+            return rid
+
+        return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
+
+    @abstractmethod
+    def _convert_to_internal_request(
+        self,
+        request: OpenAIServingRequest,
+    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
+        """Convert OpenAI request to internal format"""
+        pass
+
+    async def _handle_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: OpenAIServingRequest,
+        raw_request: Request,
+    ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
+        """Handle streaming request
+
+        Override this method in child classes that support streaming requests.
+        """
+        return self.create_error_response(
+            message=f"{self.__class__.__name__} does not support streaming requests",
+            err_type="NotImplementedError",
+            status_code=501,
+        )
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: OpenAIServingRequest,
+        raw_request: Request,
+    ) -> Union[Any, ErrorResponse, ORJSONResponse]:
+        """Handle non-streaming request
+
+        Override this method in child classes that support non-streaming requests.
+        """
+        return self.create_error_response(
+            message=f"{self.__class__.__name__} does not support non-streaming requests",
+            err_type="NotImplementedError",
+            status_code=501,
+        )
+
+    def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
+        """Validate request"""
+        pass
+
+    def create_error_response(
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: int = 400,
+        param: Optional[str] = None,
+    ) -> ORJSONResponse:
+        """Create an error response"""
+        # TODO: remove fastapi dependency in openai and move response handling to the entrypoint
+        error = ErrorResponse(
+            object="error",
+            message=message,
+            type=err_type,
+            param=param,
+            code=status_code,
+        )
+        return ORJSONResponse(content=error.model_dump(), status_code=status_code)
+
+    def create_streaming_error_response(
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: int = 400,
+    ) -> str:
+        """Create a streaming error response"""
+        error = ErrorResponse(
+            object="error",
+            message=message,
+            type=err_type,
+            param=None,
+            code=status_code,
+        )
+        return json.dumps({"error": error.model_dump()})
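
`OpenAIServingBase` centralizes validation, error formatting, and the streaming/non-streaming split; the concrete handlers introduced in this release (`OpenAIServingChat`, `OpenAIServingCompletion`, `OpenAIServingEmbedding`, `OpenAIServingScore`, `OpenAIServingRerank`) subclass it. A minimal hypothetical subclass is sketched below only to show which hooks a handler fills in; it is not the implementation from `serving_completions.py`, and the field mapping is an assumption:

```python
# Hypothetical toy handler built on OpenAIServingBase; the real subclasses
# (serving_chat.py, serving_completions.py, ...) are far more involved.
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import GenerateReqInput


class ToyServingCompletion(OpenAIServingBase):
    def _request_id_prefix(self) -> str:
        return "cmpl-"

    def _convert_to_internal_request(
        self, request: CompletionRequest
    ) -> tuple[GenerateReqInput, CompletionRequest]:
        # Map only the handful of fields this toy example cares about.
        adapted = GenerateReqInput(
            text=request.prompt,
            sampling_params={"max_new_tokens": request.max_tokens or 16},
            stream=request.stream,
        )
        return adapted, request

    async def _handle_non_streaming_request(
        self, adapted_request, request, raw_request
    ):
        # A real handler would drive tokenizer_manager.generate_request() and
        # build a CompletionResponse; here we simply echo the adapted input.
        return {"object": "text_completion", "echo": adapted_request.text}
```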