sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/utils.py
CHANGED
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional

@@ -84,28 +85,48 @@ class ReqToMetadataIdxAllocator:


 class MetadataBuffers:
-    def __init__(
-        …
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )

     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
@@ -113,6 +134,7 @@ class MetadataBuffers:
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
@@ -120,6 +142,7 @@ class MetadataBuffers:
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
@@ -130,6 +153,7 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
@@ -139,6 +163,10 @@ class MetadataBuffers:
     def set_buf(self, req: Req):

         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
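The rewritten `MetadataBuffers.__init__` above switches between a custom CUDA memory pool and plain CPU allocation through a single `with` statement, using `nullcontext()` when no pool is supplied. Below is a minimal, self-contained sketch of that pattern; the function name and buffer shapes are illustrative, and `torch.cuda.use_mem_pool` needs a recent PyTorch with CUDA available, so the usage line exercises only the CPU fallback path.

```python
from contextlib import nullcontext

import torch


def allocate_metadata(size: int, custom_mem_pool=None):
    """Allocate metadata tensors, optionally inside a dedicated CUDA memory pool."""
    # With a pool: allocate on the GPU from that pool; without one: fall back to CPU.
    device = "cuda" if custom_mem_pool else "cpu"
    ctx = torch.cuda.use_mem_pool(custom_mem_pool) if custom_mem_pool else nullcontext()
    with ctx:
        output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
        logprobs = torch.zeros((size, 16), dtype=torch.float32, device=device)
    return output_ids, logprobs


# CPU fallback path (no pool supplied):
ids, lp = allocate_metadata(8)
print(ids.shape, lp.dtype)  # torch.Size([8, 16]) torch.float32
```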
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        …
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            …
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_

-        if …
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                …
+                output_tensor_list, input_, group=self.device_group
             )

         assert (
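The `all_gather` change above lets callers pass pre-allocated `output_tensor_list` buffers (an in-place gather via `torch.distributed.all_gather`) and short-circuits the single-rank case. The sketch below reproduces the same control flow outside sglang; the single-process gloo group, helper name, and tensor sizes are all illustrative, and the concatenating code path is deliberately omitted.

```python
import os

import torch
import torch.distributed as dist


def all_gather_maybe_in_place(input_, output_tensor_list=None, group=None):
    """Gather into caller-provided buffers when given, else return the input for world size 1."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        if output_tensor_list is not None:
            output_tensor_list[0].copy_(input_)
            return None
        return input_
    if output_tensor_list is not None:
        # In-place variant: the caller owns the output buffers.
        return dist.all_gather(output_tensor_list, input_, group=group)
    raise NotImplementedError("concatenating all-gather omitted in this sketch")


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29531")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.arange(4, dtype=torch.float32)
    out = [torch.empty_like(x)]
    all_gather_maybe_in_place(x, out)
    print(out[0])  # tensor([0., 1., 2., 3.])
    dist.destroy_process_group()
```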
sglang/srt/entrypoints/engine.py
CHANGED
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop

-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
         logger.info(f"{server_args=}")

         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info

         context = zmq.Context(2)
@@ -175,7 +172,7 @@ class Engine(EngineBase):
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.…
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -258,7 +255,7 @@ class Engine(EngineBase):

         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.…
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -479,17 +476,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

-    def release_memory_occupation(self):
-        …
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )

-    def resume_memory_occupation(self):
-        …
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
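Both memory-occupation methods now take an optional `tags` list that is forwarded to the request objects, so a caller can release and later resume only a subset of GPU state rather than everything at once. A hedged usage sketch follows; the model path is illustrative and the tag strings are placeholders (the actual tag constants live elsewhere in sglang, e.g. the new `sglang/srt/constants.py` in this release).

```python
# Sketch only: requires a machine that can actually load the model.
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model

# Pre-0.4.8 behaviour still works: release and resume everything.
engine.release_memory_occupation()
engine.resume_memory_occupation()

# New in 0.4.8: scope the operation with tags. "kv_cache" is a placeholder here;
# use the tag values defined by your sglang version.
engine.release_memory_occupation(tags=["kv_cache"])
engine.resume_memory_occupation(tags=["kv_cache"])

engine.shutdown()
```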
@@ -649,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):

 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -670,11 +665,9 @@ def _launch_subprocesses(

     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []

         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +703,7 @@ def _launch_subprocesses(
                     writer,
                 ),
             )
+
             with memory_saver_adapter.configure_subprocess():
                 proc.start()
             scheduler_procs.append(proc)
@@ -735,7 +729,7 @@ def _launch_subprocesses(

     if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
         # When using `Engine` as a Python API, we don't want to block here.
-        return None, None
+        return None, None, None

     launch_dummy_health_check_server(server_args.host, server_args.port)

@@ -744,7 +738,7 @@ def _launch_subprocesses(
         logger.error(
             f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
         )
-        return None, None
+        return None, None, None

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -758,15 +752,15 @@ def _launch_subprocesses(

     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)

-    …
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )

     # Wait for the model to finish loading
     scheduler_infos = []
@@ -790,4 +784,4 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
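Template handling is now owned by the new `TemplateManager` (see `sglang/srt/managers/template_manager.py` in the file list), and `_launch_subprocesses` returns it alongside the tokenizer manager and scheduler info. Any code that called this private helper directly must unpack three values; a small sketch, with an illustrative model path:

```python
# Sketch of adapting a direct caller of the private launcher API to 0.4.8.
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative

# 0.4.7.post1: tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
# 0.4.8: a TemplateManager is returned as the second element.
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
    server_args=server_args
)
print(scheduler_info["max_req_input_len"])
```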
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -38,7 +38,8 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import …
+from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

@@ -47,6 +48,21 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ErrorResponse,
+    ModelCard,
+    ModelList,
+    ScoringRequest,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
+from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -67,26 +83,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
-    V1RerankReqInput,
     VertexGenerateReqInput,
 )
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.adapter import (
-    v1_batches,
-    v1_cancel_batch,
-    v1_chat_completions,
-    v1_completions,
-    v1_delete_file,
-    v1_embeddings,
-    v1_files_create,
-    v1_rerank,
-    v1_retrieve_batch,
-    v1_retrieve_file,
-    v1_retrieve_file_content,
-    v1_score,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -109,6 +110,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class _GlobalState:
     tokenizer_manager: TokenizerManager
+    template_manager: TemplateManager
     scheduler_info: Dict


@@ -123,6 +125,24 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
     server_args: ServerArgs = fast_api_app.server_args
+
+    # Initialize OpenAI serving handlers
+    fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_chat = OpenAIServingChat(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding(
+        _global_state.tokenizer_manager, _global_state.template_manager
+    )
+    fast_api_app.state.openai_serving_score = OpenAIServingScore(
+        _global_state.tokenizer_manager
+    )
+    fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
+        _global_state.tokenizer_manager
+    )
+
     if server_args.warmups is not None:
         await execute_warmups(
             server_args.warmups.split(","), _global_state.tokenizer_manager
@@ -148,6 +168,47 @@ app.add_middleware(
     allow_headers=["*"],
 )

+
+# Custom exception handlers to change validation error status codes
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Override FastAPI's default 422 validation error with 400"""
+    exc_str = str(exc)
+    errors_str = str(exc.errors())
+
+    if errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        message=message,
+        type=HTTPStatus.BAD_REQUEST.phrase,
+        code=HTTPStatus.BAD_REQUEST.value,
+    )
+
+    return ORJSONResponse(
+        status_code=400,
+        content=err.model_dump(),
+    )
+
+
+async def validate_json_request(raw_request: Request):
+    """Validate that the request content-type is application/json."""
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=[
+                {
+                    "loc": ["header", "content-type"],
+                    "msg": "Unsupported Media Type: Only 'application/json' is allowed",
+                    "type": "value_error",
+                }
+            ]
+        )
+
+
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

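Together, `validate_json_request` and the `RequestValidationError` handler mean that a request with the wrong content type or a body that fails schema validation gets a 400 with an OpenAI-style error payload instead of FastAPI's default 422. A self-contained sketch of the same pattern (route, model, and error shape are illustrative, not sglang's):

```python
from fastapi import Depends, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

app = FastAPI()


@app.exception_handler(RequestValidationError)
async def to_400(request: Request, exc: RequestValidationError):
    # Collapse FastAPI's default 422 into a 400 with a flat error payload.
    return JSONResponse(status_code=400, content={"error": str(exc.errors())})


async def require_json(raw_request: Request):
    # Dependency: reject anything that is not application/json.
    media_type = raw_request.headers.get("content-type", "").split(";")[0].strip().lower()
    if media_type != "application/json":
        raise RequestValidationError(
            errors=[{"loc": ["header", "content-type"], "msg": "JSON only", "type": "value_error"}]
        )


class EchoRequest(BaseModel):
    text: str


@app.post("/echo", dependencies=[Depends(require_json)])
async def echo(req: EchoRequest):
    return {"text": req.text}
```

Run with `uvicorn your_module:app`; posting to `/echo` with `Content-Type: text/plain` or with a body missing `text` now yields the 400 produced by `to_400` rather than a 422.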
@@ -330,13 +391,14 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(…
-    …
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )


 @app.api_route("/flush_cache", methods=["GET", "POST"])
@@ -619,25 +681,39 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):
 ##### OpenAI-compatible API endpoints #####


-@app.post("/v1/completions")
-async def openai_v1_completions(raw_request: Request):
-    …
+@app.post("/v1/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_completions(request: CompletionRequest, raw_request: Request):
+    """OpenAI-compatible text completion endpoint."""
+    return await raw_request.app.state.openai_serving_completion.handle_request(
+        request, raw_request
+    )


-@app.post("/v1/chat/completions")
-async def openai_v1_chat_completions(
-    …
+@app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
+async def openai_v1_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


-@app.post(
-    …
+@app.post(
+    "/v1/embeddings",
+    response_class=ORJSONResponse,
+    dependencies=[Depends(validate_json_request)],
+)
+async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
+    """OpenAI-compatible embeddings endpoint."""
+    return await raw_request.app.state.openai_serving_embedding.handle_request(
+        request, raw_request
+    )


 @app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
-    """Show available models."""
+async def available_models():
+    """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
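These routes now declare typed request models and delegate to the serving objects registered on `app.state`, so clients must send a JSON body with the `application/json` content type. A minimal client-side sketch against a locally running server; the host, port (30000 is sglang's default), and model name depend on your deployment.

```python
import requests

BASE = "http://127.0.0.1:30000"  # adjust to your server

resp = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": "default",  # any served model name; see GET /v1/models
        "messages": [{"role": "user", "content": "Say hi in one word."}],
        "max_tokens": 8,
    },
)
print(resp.status_code, resp.json())
```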
@@ -651,45 +727,29 @@ def available_models():
     return ModelList(data=model_cards)


-@app.…
-async def …
-    …
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/delete
-    return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
-    # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
-    return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
-    # https://platform.openai.com/docs/api-reference/files/retrieve
-    return await v1_retrieve_file(file_id)
+@app.get("/v1/models/{model:path}", response_class=ORJSONResponse)
+async def retrieve_model(model: str):
+    """Retrieves a model instance, providing basic information about the model."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]

+    if model not in served_model_names:
+        return ORJSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )

-    …
+    return ModelCard(
+        id=model,
+        root=model,
+        max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+    )


 ## SageMaker API
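The files and batches routes that relied on the removed `openai_api.adapter` module are gone; in their place, `GET /v1/models/{model}` returns a `ModelCard` for a served model or a 404 in the OpenAI error format. A quick client-side check (host and port illustrative):

```python
import requests

base = "http://127.0.0.1:30000"  # adjust to your server

model_id = requests.get(f"{base}/v1/models").json()["data"][0]["id"]

ok = requests.get(f"{base}/v1/models/{model_id}")
print(ok.status_code, ok.json().get("max_model_len"))

missing = requests.get(f"{base}/v1/models/not-a-model")
print(missing.status_code, missing.json()["error"]["code"])  # 404 model_not_found
```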
@@ -700,8 +760,13 @@ async def sagemaker_health() -> Response:


 @app.post("/invocations")
-async def sagemaker_chat_completions(
-    …
+async def sagemaker_chat_completions(
+    request: ChatCompletionRequest, raw_request: Request
+):
+    """OpenAI-compatible chat completion endpoint."""
+    return await raw_request.app.state.openai_serving_chat.handle_request(
+        request, raw_request
+    )


 ## Vertex AI API
@@ -732,10 +797,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})


-@app.post("/v1/score")
-async def v1_score_request(raw_request: Request):
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
     """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await …
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )


 def _create_error_response(e):
@@ -764,10 +831,13 @@ def launch_server(
     1. The HTTP server, Engine, and TokenizerManager both run in the main process.
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(
+    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
+        server_args=server_args
+    )
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
+            template_manager=template_manager,
             scheduler_info=scheduler_info,
         )
     )
sglang/srt/entrypoints/http_server_engine.py
CHANGED
@@ -64,11 +64,9 @@ class HttpServerEngineAdapter(EngineBase):

     def _make_request(self, endpoint: str, payload: Optional[dict] = None):
         """Make a POST request to the specified endpoint with the given payload.
-
         Args:
             endpoint: The API endpoint to call
             payload: The JSON payload to send (default: empty dict)
-
         Returns:
             The JSON response from the server
         """
@@ -85,7 +83,6 @@ class HttpServerEngineAdapter(EngineBase):
     ):
         """
         Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
-
         Note: The model should be on GPUs rather than CPU for this functionality to work properly.
         If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
         """