sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/http_server.py
@@ -38,7 +38,8 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

@@ -47,6 +48,21 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ErrorResponse,
+    ModelCard,
+    ModelList,
+    ScoringRequest,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -67,26 +83,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
-    V1RerankReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
-    v1_batches,
-    v1_cancel_batch,
-    v1_chat_completions,
-    v1_completions,
-    v1_delete_file,
-    v1_embeddings,
-    v1_files_create,
-    v1_rerank,
-    v1_retrieve_batch,
-    v1_retrieve_file,
-    v1_retrieve_file_content,
-    v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -109,6 +110,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class _GlobalState:
     tokenizer_manager: TokenizerManager
+    template_manager: TemplateManager
     scheduler_info: Dict

@@ -123,6 +125,24 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
     server_args: ServerArgs = fast_api_app.server_args
+
+    # Initialize OpenAI serving handlers
+    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_score = OpenAIServingScore(
+        _global_state.tokenizer_manager
+    )
+    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+        _global_state.tokenizer_manager
+    )
+
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.warmups.split(","), _global_state.tokenizer_manager
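
Note: the hunk above builds each serving handler once at startup and caches it on app.state, so route handlers can reach them through raw_request.app.state instead of module globals. A minimal standalone sketch of the same FastAPI pattern (the EchoHandler class and route here are illustrative, not part of sglang):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI, Request

    class EchoHandler:  # stand-in for e.g. OpenAIServingChat
        async def handle_request(self, payload: dict) -> dict:
            return {"echo": payload}

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Build expensive handlers once, before the first request is served.
        app.state.echo_handler = EchoHandler()
        yield  # the server runs while suspended here

    app = FastAPI(lifespan=lifespan)

    @app.post("/echo")
    async def echo(raw_request: Request):
        handler = raw_request.app.state.echo_handler
        return await handler.handle_request(await raw_request.json())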
@@ -148,6 +168,47 @@ app.add_middleware(
     allow_headers=["*"],
 )

+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Override FastAPI's default 422 validation error with 400"""
+    exc_str = str(exc)
+    errors_str = str(exc.errors())
+
+    if errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        message=message,
+        type=HTTPStatus.BAD_REQUEST.phrase,
+        code=HTTPStatus.BAD_REQUEST.value,
+    )
+
+    return ORJSONResponse(
+        status_code=400,
+        content=err.model_dump(),
+    )
+
+
+async def validate_json_request(raw_request: Request):
+    """Validate that the request content-type is application/json."""
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=[
+                {
+                    "loc": ["header", "content-type"],
+                    "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+                    "type": "value_error",
+                }
+            ]
+        )
+
+
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

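Note: with validate_json_request attached as a route dependency and the handler above downgrading RequestValidationError to 400, a request with the wrong Content-Type is rejected before the endpoint body runs. A hypothetical client session (host, port, and the exact ErrorResponse field names are assumptions):

    import requests

    # Wrong content-type: validate_json_request raises RequestValidationError,
    # which the exception handler converts into a 400 instead of FastAPI's 422.
    r = requests.post(
        "http://localhost:30000/v1/chat/completions",  # assumed local server
        data='{"model": "default", "messages": []}',
        headers={"Content-Type": "text/plain"},
    )
    assert r.status_code == 400
    print(r.json())  # ErrorResponse-shaped body with message/type/code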
@@ -330,13 +391,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route("/v1/rerank", methods=["POST", "PUT"])
-async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
-    try:
-        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )


 @app.api_route("/flush_cache", methods=["GET", "POST"])
@@ -619,25 +681,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
 ##### OpenAI-compatible API endpoints #####


-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-async def openai_v1_embeddings(raw_request: Request):
-    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
-    return response
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )


 @app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
-    """Show available models."""
+async def available_models():
+    """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
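
Note: each endpoint now receives a typed Pydantic model (CompletionRequest, ChatCompletionRequest, EmbeddingRequest) instead of parsing the raw body inside the old adapter, so malformed payloads fail validation before the handler runs. A hypothetical client call, assuming a local server, a served model named "default", and the standard OpenAI response shape:

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/chat/completions",  # assumed local server
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "Say hi in one word."}],
            "max_tokens": 8,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])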
@@ -651,45 +727,29 @@ def available_models():
     return ModelList(data=model_cards)


-@app.post("/v1/files")
-async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
-    return await v1_files_create(
-        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
-    )
-
-
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/delete
-    return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
-    # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
-    return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve
-    return await v1_retrieve_file(file_id)
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+    """Retrieves a model instance, providing basic information about the model."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]

+    if model not in served_model_names:
+        return ORJSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )

-@app.get("/v1/files/{file_id}/content")
-async def retrieve_file_content(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
-    return await v1_retrieve_file_content(file_id)
+    return ModelCard(
+        id=model,
+        root=model,
+        max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+    )


 ## SageMaker API
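
Note: the /v1/files and /v1/batches endpoints are removed together with the old openai_api adapter (item 120 in the file list above), while the new retrieve-model route mirrors OpenAI's behavior: a served name returns a ModelCard, anything else a structured 404. For example (address and served model name assumed):

    import requests

    ok = requests.get("http://localhost:30000/v1/models/default")
    print(ok.json()["max_model_len"])

    missing = requests.get("http://localhost:30000/v1/models/not-a-model")
    assert missing.status_code == 404
    assert missing.json()["error"]["code"] == "model_not_found"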
@@ -700,8 +760,13 @@ async def sagemaker_health() -> Response:


 @app.post("/invocations")
-async def sagemaker_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


 ## Vertex AI API
@@ -732,10 +797,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
     return ORJSONResponse({"predictions": ret})


-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await v1_score(_global_state.tokenizer_manager, raw_request)
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )


 def _create_error_response(e):
@@ -764,10 +831,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )
sglang/srt/entrypoints/http_server_engine.py
@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):

     def _make_request(self, endpoint: str, payload: Optional[dict] = None):
         """Make a POST request to the specified endpoint with the given payload.
-
         Args:
             endpoint: The API endpoint to call
             payload: The JSON payload to send (default: empty dict)
-
         Returns:
             The JSON response from the server
         """
@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
     ):
         """
         Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
-
         Note: The model should be on GPUs rather than CPU for this functionality to work properly.
         If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
         """
sglang/srt/entrypoints/openai/__init__.py
File without changes
sglang/srt/{openai_api → entrypoints/openai}/protocol.py
@@ -14,9 +14,16 @@
 """Pydantic models for OpenAI API protocol"""

 import time
-from typing import Dict, List, Optional, Union
-
-from pydantic import BaseModel, Field, model_serializer, root_validator
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal


@@ -167,6 +174,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -182,25 +190,37 @@ class CompletionRequest(BaseModel):
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
-    return_hidden_states: Optional[bool] = False

     # For PD disaggregation
     bootstrap_host: Optional[str] = None
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # For request id
+    rid: Optional[Union[List[str], str]] = None
+
+    @field_validator("max_tokens")
+    @classmethod
+    def validate_max_tokens_positive(cls, v):
+        if v is not None and v <= 0:
+            raise ValueError("max_tokens must be positive")
+        return v
+

 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter", "abort"]
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):
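
Note: the bare @model_serializer previously replaced the entire dump through the module-level exclude_if_none helper (removed at the bottom of this file); mode="wrap" instead receives pydantic's default serializer as handler and post-processes its output. A standalone illustration of the pattern (pydantic v2 assumed):

    from typing import Optional

    from pydantic import BaseModel, model_serializer

    class Choice(BaseModel):
        text: str
        hidden_states: Optional[object] = None

        @model_serializer(mode="wrap")
        def _serialize(self, handler):
            data = handler(self)  # default field-by-field dump
            if self.hidden_states is None:
                data.pop("hidden_states", None)  # drop the key entirely
            return data

    print(Choice(text="hi").model_dump())
    # {'text': 'hi'}
    print(Choice(text="hi", hidden_states=[0.1]).model_dump())
    # {'text': 'hi', 'hidden_states': [0.1]}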
@@ -220,9 +240,12 @@ class CompletionResponseStreamChoice(BaseModel):
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):
@@ -290,6 +313,18 @@ class ChatCompletionMessageGenericParam(BaseModel):
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])

+    @field_validator("role", mode="before")
+    @classmethod
+    def _normalize_role(cls, v):
+        if isinstance(v, str):
+            v_lower = v.lower()
+            if v_lower not in {"system", "assistant", "tool"}:
+                raise ValueError(
+                    "'role' must be one of 'system', 'assistant', or 'tool' (case-insensitive)."
+                )
+            return v_lower
+        raise ValueError("'role' must be a string")
+

 class ChatCompletionMessageUserParam(BaseModel):
     role: Literal["user"]
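
Note: because the new validator runs with mode="before", casing is normalized before the field's own type check, so "System" is accepted and coerced while unknown roles are rejected with a clear message. A trimmed standalone check of the same validator logic (the model here is a simplified stand-in):

    from typing import Literal

    from pydantic import BaseModel, ValidationError, field_validator

    class GenericMessage(BaseModel):
        role: Literal["system", "assistant", "tool"]

        @field_validator("role", mode="before")
        @classmethod
        def _normalize_role(cls, v):
            if isinstance(v, str):
                v_lower = v.lower()
                if v_lower not in {"system", "assistant", "tool"}:
                    raise ValueError("'role' must be one of 'system', 'assistant', or 'tool'")
                return v_lower
            raise ValueError("'role' must be a string")

    print(GenericMessage(role="System").role)  # -> "system"
    try:
        GenericMessage(role="user")            # user messages use a separate model
    except ValidationError as e:
        print("rejected:", e.errors()[0]["msg"])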
@@ -380,8 +415,10 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def set_tool_choice_default(cls, values):
         if values.get("tool_choice") is None:
             if values.get("tools") is None:
@@ -408,17 +445,14 @@ class ChatCompletionRequest(BaseModel):
     stream_reasoning: bool = True
     chat_template_kwargs: Optional[Dict] = None

-    # The request id.
-    rid: Optional[str] = None
+    # For request id
+    rid: Optional[Union[List[str], str]] = None

     # For PD disaggregation
     bootstrap_host: Optional[str] = None
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

-    # Hidden States
-    return_hidden_states: Optional[bool] = False
-

 class ChatMessage(BaseModel):
     role: Optional[str] = None
@@ -431,15 +465,20 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason: Literal[
-        "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
-    ]
+    finish_reason: Optional[
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
+    ] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):
@@ -458,9 +497,12 @@ class DeltaMessage(BaseModel):
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
     hidden_states: Optional[object] = None

-    @model_serializer
-    def _serialize(self):
-        return exclude_if_none(self, ["hidden_states"])
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -487,19 +529,22 @@ class MultimodalEmbeddingInput(BaseModel):
     image: Optional[str] = None


+EmbeddingInput = Union[
+    List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+]
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[
-        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
-    ]
+    input: EmbeddingInput
     model: str
     encoding_format: str = "float"
-    dimensions: int = None
+    dimensions: Optional[int] = None
     user: Optional[str] = None

     # The request id.
-    rid: Optional[str] = None
+    rid: Optional[Union[List[str], str]] = None


 class EmbeddingObject(BaseModel):
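
Note: two small but meaningful fixes above: dimensions: int = None was an invalid annotation (None is not an int) and becomes Optional[int], and rid widens so batch requests can carry one id per input. Validating against a trimmed copy of the model (the multimodal input variant is omitted here for brevity):

    from typing import List, Optional, Union

    from pydantic import BaseModel

    EmbeddingInput = Union[List[int], List[List[int]], str, List[str]]

    class EmbeddingRequest(BaseModel):
        input: EmbeddingInput
        model: str
        encoding_format: str = "float"
        dimensions: Optional[int] = None
        user: Optional[str] = None
        rid: Optional[Union[List[str], str]] = None

    req = EmbeddingRequest(
        input=["first text", "second text"],
        model="default",
        rid=["req-0", "req-1"],  # one request id per batch element
    )
    print(req.rid)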
@@ -539,6 +584,11 @@ class ScoringResponse(BaseModel):
     object: str = "scoring"


+class V1RerankReqInput(BaseModel):
+    query: str
+    documents: List[str]
+
+
 class RerankResponse(BaseModel):
     score: float
     document: str
@@ -546,6 +596,37 @@ class RerankResponse(BaseModel):
     meta_info: Optional[dict] = None


-def exclude_if_none(obj, field_names: List[str]):
-    omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
-    return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
+OpenAIServingRequest = Union[
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ScoringRequest,
+    V1RerankReqInput,
+]
+
+
+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all the outputs from message processing including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Args:
+        prompt: The final text prompt after applying chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from template and request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None
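
Note: a usage sketch for the new dataclass; the values are purely illustrative, since OpenAIServingChat constructs this internally after applying the chat template:

    from sglang.srt.entrypoints.openai.protocol import MessageProcessingResult

    result = MessageProcessingResult(
        prompt="<|user|>Hello<|assistant|>",   # templated text (illustrative)
        prompt_ids=[1, 529, 29989],            # or the prompt string itself
        image_data=None,
        audio_data=None,
        modalities=[],
        stop=["<|user|>"],                     # template stops merged with request stops
    )
    print(result.tool_call_constraint)         # None unless tool calls constrain decoding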