sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/http_server.py

@@ -38,15 +38,31 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import (
-
+    FAKE_BOOTSTRAP_HOST,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ErrorResponse,
+    ModelCard,
+    ModelList,
+    ScoringRequest,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -69,22 +85,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
-    v1_batches,
-    v1_cancel_batch,
-    v1_chat_completions,
-    v1_completions,
-    v1_delete_file,
-    v1_embeddings,
-    v1_files_create,
-    v1_retrieve_batch,
-    v1_retrieve_file,
-    v1_retrieve_file_content,
-    v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -107,6 +110,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class _GlobalState:
     tokenizer_manager: TokenizerManager
+    template_manager: TemplateManager
     scheduler_info: Dict


@@ -121,6 +125,24 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
     server_args: ServerArgs = fast_api_app.server_args
+
+    # Initialize OpenAI serving handlers
+    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_score = OpenAIServingScore(
+        _global_state.tokenizer_manager
+    )
+    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+        _global_state.tokenizer_manager
+    )
+
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -146,6 +168,47 @@ app.add_middleware(
     allow_headers=["*"],
 )

+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Override FastAPI's default 422 validation error with 400"""
+    exc_str = str(exc)
+    errors_str = str(exc.errors())
+
+    if errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        message=message,
+        type=HTTPStatus.BAD_REQUEST.phrase,
+        code=HTTPStatus.BAD_REQUEST.value,
+    )
+
+    return ORJSONResponse(
+        status_code=400,
+        content=err.model_dump(),
+    )
+
+
+async def validate_json_request(raw_request: Request):
+    """Validate that the request content-type is application/json."""
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=[
+                {
+                    "loc": ["header", "content-type"],
+                    "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+                    "type": "value_error",
+                }
+            ]
+        )
+
+
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

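The two helpers above change client-visible behavior: malformed bodies and non-JSON content types on the routes that declare the dependency now surface as HTTP 400 with an ErrorResponse payload instead of FastAPI's default 422. A minimal client-side sketch of that behavior, assuming a locally running sglang server; the host and port below are placeholders:

# Sketch only: assumes an sglang HTTP server is reachable at this address.
import requests

BASE = "http://localhost:30000"

# Non-JSON content type: validate_json_request raises RequestValidationError,
# which the custom handler converts into a 400 ErrorResponse.
r = requests.post(
    f"{BASE}/v1/completions",
    data="prompt=hello",
    headers={"Content-Type": "text/plain"},
)
print(r.status_code)  # expected: 400 (previously a 422 from FastAPI)
print(r.json())       # expected: ErrorResponse fields, type "Bad Request"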
@@ -328,6 +391,16 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
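For reference, a hedged sketch of calling the new rerank route; the payload shape follows V1RerankReqInput (query plus documents) defined in the protocol changes further below, and each returned item follows RerankResponse. The host and port are placeholders:

import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
)
# Expected items look like {"score": ..., "document": ..., "index": ..., "meta_info": ...}
print(resp.json())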
@@ -608,25 +681,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):
 ##### OpenAI-compatible API endpoints #####


-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(
-
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


-@app.post(
-
-
-
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )


 @app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
-    """Show available models."""
+async def available_models():
+    """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
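Because request bodies are now parsed into CompletionRequest / ChatCompletionRequest / EmbeddingRequest models before being dispatched to the app.state handlers, the endpoints stay drop-in compatible with the official OpenAI client. A sketch, where base_url, api_key, and the model name are placeholders rather than sglang defaults:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my-served-model",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(chat.choices[0].message.content)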
@@ -640,45 +727,29 @@ def available_models():
     return ModelList(data=model_cards)


-@app.
-async def
-
-
-)
-
-
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/delete
-    return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
-    # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
-    return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve
-    return await v1_retrieve_file(file_id)
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+    """Retrieves a model instance, providing basic information about the model."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]

+    if model not in served_model_names:
+        return ORJSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )

-
-
-
-
+    return ModelCard(
+        id=model,
+        root=model,
+        max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+    )


 ## SageMaker API
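The batch and file routes that depended on the removed openai_api adapter are gone, and a GET /v1/models/{model} route is added in their place: a served model returns a ModelCard, anything else a 404 with code "model_not_found". A sketch of the new route (URL and model names are placeholders):

import requests

ok = requests.get("http://localhost:30000/v1/models/my-served-model")
print(ok.status_code, ok.json().get("max_model_len"))

missing = requests.get("http://localhost:30000/v1/models/no-such-model")
print(missing.status_code, missing.json()["error"]["code"])  # expected: 404 model_not_found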
@@ -689,8 +760,13 @@ async def sagemaker_health() -> Response:


 @app.post("/invocations")
-async def sagemaker_chat_completions(
-
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


 ## Vertex AI API

@@ -721,10 +797,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})


-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )


 def _create_error_response(e):

@@ -753,10 +831,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )

@@ -878,7 +959,7 @@ def _wait_and_warmup(
                 "max_new_tokens": 8,
                 "ignore_eos": True,
             },
-            "bootstrap_host": [
+            "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
            # This is a hack to ensure fake transfer is enabled during prefill warmup
            # ensure each dp rank has a unique bootstrap_room during prefill warmup
            "bootstrap_room": [
sglang/srt/entrypoints/http_server_engine.py

@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):

     def _make_request(self, endpoint: str, payload: Optional[dict] = None):
         """Make a POST request to the specified endpoint with the given payload.
-
         Args:
             endpoint: The API endpoint to call
             payload: The JSON payload to send (default: empty dict)
-
         Returns:
             The JSON response from the server
         """

@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
     ):
         """
         Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
-
         Note: The model should be on GPUs rather than CPU for this functionality to work properly.
         If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
         """
sglang/srt/entrypoints/openai/__init__.py: file without changes
sglang/srt/entrypoints/openai/protocol.py (renamed from sglang/srt/openai_api/protocol.py)

@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal


@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -188,13 +195,28 @@
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    @field_validator("max_tokens")
+    @classmethod
+    def validate_max_tokens_positive(cls, v):
+        if v is not None and v <= 0:
+            raise ValueError("max_tokens must be positive")
+        return v
+

 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter", "abort"]
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):
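The wrap-mode model_serializer added here (and repeated on the other response models below) keeps the new hidden_states field out of serialized responses whenever it is None, so existing clients see unchanged payloads. A self-contained illustration of the pattern outside of sglang:

from typing import Optional
from pydantic import BaseModel, model_serializer


class Choice(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        # handler(self) produces the default serialization; drop the field when unset.
        data = handler(self)
        if self.hidden_states is None:
            data.pop("hidden_states", None)
        return data


print(Choice(text="hi").model_dump())                       # {'text': 'hi'}
print(Choice(text="hi", hidden_states=[0.1]).model_dump())  # includes hidden_states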
@@ -212,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):

@@ -369,8 +399,10 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

-    @
+    @model_validator(mode="before")
+    @classmethod
     def set_tool_choice_default(cls, values):
         if values.get("tool_choice") is None:
             if values.get("tools") is None:
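return_hidden_states is an SRT-side extension rather than an OpenAI parameter, so with the official client it would be passed through extra_body; whether hidden states are actually populated also depends on how the server is configured. A hedged sketch with placeholder base_url and model name:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="my-served-model",
    messages=[{"role": "user", "content": "hello"}],
    extra_body={"return_hidden_states": True},
)
# When the server returns them, each choice carries a hidden_states entry;
# otherwise the serializer shown above omits the field entirely.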
@@ -417,10 +449,20 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
-    finish_reason:
-
-
+    finish_reason: Optional[
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
+    ] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):

@@ -437,6 +479,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):

@@ -463,15 +513,18 @@ class MultimodalEmbeddingInput(BaseModel):
     image: Optional[str] = None


+EmbeddingInput = Union[
+    List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+]
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input:
-        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
-    ]
+    input: EmbeddingInput
     model: str
     encoding_format: str = "float"
-    dimensions: int = None
+    dimensions: Optional[int] = None
     user: Optional[str] = None

     # The request id.

@@ -513,3 +566,24 @@ class ScoringResponse(BaseModel):
     model: str
     usage: Optional[UsageInfo] = None
     object: str = "scoring"
+
+
+class V1RerankReqInput(BaseModel):
+    query: str
+    documents: List[str]
+
+
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
+
+OpenAIServingRequest = Union[
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ScoringRequest,
+    V1RerankReqInput,
+]
sglang/srt/entrypoints/openai/serving_base.py (new file)

@@ -0,0 +1,149 @@
+import json
+import logging
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse, StreamingResponse
+
+from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__name__)
+
+
+# Base class for specific endpoint handlers
+class OpenAIServingBase(ABC):
+    """Abstract base class for OpenAI endpoint handlers"""
+
+    def __init__(self, tokenizer_manager: TokenizerManager):
+        self.tokenizer_manager = tokenizer_manager
+
+    async def handle_request(
+        self, request: OpenAIServingRequest, raw_request: Request
+    ) -> Union[Any, StreamingResponse, ErrorResponse]:
+        """Handle the specific request type with common pattern"""
+        try:
+            # Validate request
+            error_msg = self._validate_request(request)
+            if error_msg:
+                return self.create_error_response(error_msg)
+
+            # Convert to internal format
+            adapted_request, processed_request = self._convert_to_internal_request(
+                request
+            )
+
+            # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
+            if hasattr(request, "stream") and request.stream:
+                return await self._handle_streaming_request(
+                    adapted_request, processed_request, raw_request
+                )
+            else:
+                return await self._handle_non_streaming_request(
+                    adapted_request, processed_request, raw_request
+                )
+
+        except Exception as e:
+            logger.exception(f"Error in request: {e}")
+            return self.create_error_response(
+                message=f"Internal server error: {str(e)}",
+                err_type="InternalServerError",
+                status_code=500,
+            )
+
+    @abstractmethod
+    def _request_id_prefix(self) -> str:
+        """Generate request ID based on request type"""
+        pass
+
+    def _generate_request_id_base(self, request: OpenAIServingRequest) -> Optional[str]:
+        """Generate request ID based on request type"""
+        return None
+
+        # TODO(chang): the rid is used in io_strcut check and often violates `The rid should be a list` AssertionError
+        # Temporarily return None in this function until the rid logic is clear.
+        if rid := getattr(request, "rid", None):
+            return rid
+
+        return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
+
+    @abstractmethod
+    def _convert_to_internal_request(
+        self,
+        request: OpenAIServingRequest,
+    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
+        """Convert OpenAI request to internal format"""
+        pass
+
+    async def _handle_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: OpenAIServingRequest,
+        raw_request: Request,
+    ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]:
+        """Handle streaming request
+
+        Override this method in child classes that support streaming requests.
+        """
+        return self.create_error_response(
+            message=f"{self.__class__.__name__} does not support streaming requests",
+            err_type="NotImplementedError",
+            status_code=501,
+        )
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: OpenAIServingRequest,
+        raw_request: Request,
+    ) -> Union[Any, ErrorResponse, ORJSONResponse]:
+        """Handle non-streaming request
+
+        Override this method in child classes that support non-streaming requests.
+        """
+        return self.create_error_response(
+            message=f"{self.__class__.__name__} does not support non-streaming requests",
+            err_type="NotImplementedError",
+            status_code=501,
+        )
+
+    def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]:
+        """Validate request"""
+        pass
+
+    def create_error_response(
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: int = 400,
+        param: Optional[str] = None,
+    ) -> ORJSONResponse:
+        """Create an error response"""
+        # TODO: remove fastapi dependency in openai and move response handling to the entrypoint
+        error = ErrorResponse(
+            object="error",
+            message=message,
+            type=err_type,
+            param=param,
+            code=status_code,
+        )
+        return ORJSONResponse(content=error.model_dump(), status_code=status_code)
+
+    def create_streaming_error_response(
+        self,
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: int = 400,
+    ) -> str:
+        """Create a streaming error response"""
+        error = ErrorResponse(
+            object="error",
+            message=message,
+            type=err_type,
+            param=None,
+            code=status_code,
+        )
+        return json.dumps({"error": error.model_dump()})