sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ sglang/srt/entrypoints/openai/serving_completions.py
@@ -0,0 +1,424 @@
+import logging
+import time
+from typing import Any, AsyncGenerator, Dict, List, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse, StreamingResponse
+
+from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
+from sglang.srt.entrypoints.openai.protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    ErrorResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+from sglang.srt.entrypoints.openai.utils import (
+    process_hidden_states_from_ret,
+    to_openai_style_logprobs,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingCompletion(OpenAIServingBase):
+    """Handler for /v1/completion requests"""
+
+    def __init__(
+        self,
+        tokenizer_manager: TokenizerManager,
+        template_manager: TemplateManager,
+    ):
+        super().__init__(tokenizer_manager)
+        self.template_manager = template_manager
+
+    def _request_id_prefix(self) -> str:
+        return "cmpl-"
+
+    def _convert_to_internal_request(
+        self,
+        request: CompletionRequest,
+    ) -> tuple[GenerateReqInput, CompletionRequest]:
+        """Convert OpenAI completion request to internal format"""
+        # NOTE: with openai API, the prompt's logprobs are always not computed
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use the native /generate API."
+            )
+        # Process prompt
+        prompt = request.prompt
+        if self.template_manager.completion_template_name is not None:
+            prompt = generate_completion_prompt_from_request(request)
+
+        # Set logprob start length based on echo and logprobs
+        if request.echo and request.logprobs:
+            logprob_start_len = 0
+        else:
+            logprob_start_len = -1
+
+        # Build sampling parameters
+        sampling_params = self._build_sampling_params(request)
+
+        # Determine prompt format
+        if isinstance(prompt, str) or (
+            isinstance(prompt, list) and isinstance(prompt[0], str)
+        ):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+
+        adapted_request = GenerateReqInput(
+            **prompt_kwargs,
+            sampling_params=sampling_params,
+            return_logprob=request.logprobs is not None,
+            top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+            logprob_start_len=logprob_start_len,
+            return_text_in_logprobs=True,
+            stream=request.stream,
+            lora_path=request.lora_path,
+            bootstrap_host=request.bootstrap_host,
+            bootstrap_port=request.bootstrap_port,
+            bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
+        )
+
+        return adapted_request, request
+
+    def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
+        """Build sampling parameters for the request"""
+        # Start with common parameters
+        sampling_params = {
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "min_new_tokens": request.min_tokens,
+            "stop": request.stop,
+            "stop_token_ids": request.stop_token_ids,
+            "top_p": request.top_p,
+            "top_k": request.top_k,
+            "min_p": request.min_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "repetition_penalty": request.repetition_penalty,
+            "regex": request.regex,
+            "json_schema": request.json_schema,
+            "ebnf": request.ebnf,
+            "n": request.n,
+            "no_stop_trim": request.no_stop_trim,
+            "ignore_eos": request.ignore_eos,
+            "skip_special_tokens": request.skip_special_tokens,
+            "logit_bias": request.logit_bias,
+        }
+
+        return sampling_params
+
+    async def _handle_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: CompletionRequest,
+        raw_request: Request,
+    ) -> StreamingResponse:
+        """Handle streaming completion request"""
+        return StreamingResponse(
+            self._generate_completion_stream(adapted_request, request, raw_request),
+            media_type="text/event-stream",
+            background=self.tokenizer_manager.create_abort_task(adapted_request),
+        )
+
+    async def _generate_completion_stream(
+        self,
+        adapted_request: GenerateReqInput,
+        request: CompletionRequest,
+        raw_request: Request,
+    ) -> AsyncGenerator[str, None]:
+        """Generate streaming completion response"""
+        created = int(time.time())
+
+        # State tracking for streaming
+        stream_buffers = {}
+        n_prev_tokens = {}
+
+        # Usage tracking
+        prompt_tokens = {}
+        completion_tokens = {}
+        cached_tokens = {}
+        hidden_states = {}
+
+        try:
+            async for content in self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ):
+                index = content.get("index", 0)
+
+                text = content["text"]
+                prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                completion_tokens[index] = content["meta_info"]["completion_tokens"]
+                cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)
+
+                stream_buffer = stream_buffers.get(index, "")
+                # Handle echo for first chunk
+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        echo_text = self._get_echo_text(request, index)
+                        text = echo_text + text
+
+                # Handle logprobs
+                logprobs = None
+                if request.logprobs is not None:
+                    # The first chunk and echo is enabled.
+                    if not stream_buffer and request.echo:
+                        input_token_logprobs = content["meta_info"][
+                            "input_token_logprobs"
+                        ]
+                        input_top_logprobs = content["meta_info"]["input_top_logprobs"]
+                    else:
+                        input_token_logprobs = None
+                        input_top_logprobs = None
+
+                    n_prev_token = n_prev_tokens.get(index, 0)
+                    logprobs = to_openai_style_logprobs(
+                        input_token_logprobs=input_token_logprobs,
+                        input_top_logprobs=input_top_logprobs,
+                        output_token_logprobs=content["meta_info"][
+                            "output_token_logprobs"
+                        ][n_prev_token:],
+                        output_top_logprobs=content["meta_info"]["output_top_logprobs"][
+                            n_prev_token:
+                        ],
+                    )
+                    n_prev_tokens[index] = len(
+                        content["meta_info"]["output_token_logprobs"]
+                    )
+
+                # Generate delta
+                delta = text[len(stream_buffer) :]
+                stream_buffers[index] = stream_buffer + delta
+                finish_reason = content["meta_info"]["finish_reason"]
+
+                choice_data = CompletionResponseStreamChoice(
+                    index=index,
+                    text=delta,
+                    logprobs=logprobs,
+                    finish_reason=finish_reason["type"] if finish_reason else None,
+                    matched_stop=(
+                        finish_reason["matched"]
+                        if finish_reason and "matched" in finish_reason
+                        else None
+                    ),
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                )
+
+                yield f"data: {chunk.model_dump_json()}\n\n"
+
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            object="text_completion",
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    index=index,
+                                    text="",
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
+            # Handle final usage chunk
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageProcessor.calculate_streaming_usage(
+                    prompt_tokens,
+                    completion_tokens,
+                    cached_tokens,
+                    n_choices=request.n,
+                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
+                )
+                final_usage_chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
+                yield f"data: {final_usage_data}\n\n"
+
+        except Exception as e:
+            error = self.create_streaming_error_response(str(e))
+            yield f"data: {error}\n\n"
+
+        yield "data: [DONE]\n\n"
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: CompletionRequest,
+        raw_request: Request,
+    ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
+        """Handle non-streaming completion request"""
+        try:
+            generator = self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            )
+            ret = await generator.__anext__()
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        if not isinstance(ret, list):
+            ret = [ret]
+
+        response = self._build_completion_response(
+            request,
+            ret,
+            int(time.time()),
+        )
+
+        return response
+
+    def _build_completion_response(
+        self,
+        request: CompletionRequest,
+        ret: List[Dict[str, Any]],
+        created: int,
+    ) -> CompletionResponse:
+        """Build completion response from generation results"""
+        choices = []
+        echo = False
+
+        # Prepare echo prompts if needed
+        echo_prompts = []
+        if request.echo:
+            echo_prompts = self._prepare_echo_prompts(request)
+            echo = True
+
+        for idx, ret_item in enumerate(ret):
+            text = ret_item["text"]
+
+            # Handle echo
+            if echo:
+                prompt_index = idx // request.n
+                text = echo_prompts[prompt_index] + text
+
+            # Handle logprobs
+            logprobs = None
+            if request.logprobs is not None:
+                if echo:
+                    input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                    input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+                else:
+                    input_token_logprobs = None
+                    input_top_logprobs = None
+
+                logprobs = to_openai_style_logprobs(
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_token_logprobs=ret_item["meta_info"][
+                        "output_token_logprobs"
+                    ],
+                    output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+                )
+
+            # Handle hidden states
+            hidden_states = process_hidden_states_from_ret(ret_item, request)
+
+            finish_reason = ret_item["meta_info"]["finish_reason"]
+
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=finish_reason["type"] if finish_reason else None,
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
+                ),
+                hidden_states=hidden_states,
+            )
+            choices.append(choice_data)
+
+        # Calculate usage
+        cache_report = self.tokenizer_manager.server_args.enable_cache_report
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
+
+        return CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            created=created,
+            choices=choices,
+            usage=usage,
+        )
+
+    def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
+        """Get echo text for streaming response"""
+        if isinstance(request.prompt, str):
+            # for the case of single str prompts
+            return request.prompt
+        elif isinstance(request.prompt, list):
+            if isinstance(request.prompt[0], str):
+                # for the case of multiple str prompts
+                return request.prompt[index // request.n]
+            elif isinstance(request.prompt[0], int):
+                # for the case of single token ids prompt
+                return self.tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            elif isinstance(request.prompt[0], list) and isinstance(
+                request.prompt[0][0], int
+            ):
+                # for the case of multiple token ids prompts
+                return self.tokenizer_manager.tokenizer.decode(
+                    request.prompt[index // request.n],
+                    skip_special_tokens=True,
+                )
+        return ""
+
+    def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
+        """Prepare echo prompts for non-streaming response"""
+        # TODO: handle the case prompt is token ids
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
+            return request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            return [
+                self.tokenizer_manager.tokenizer.decode(
+                    prompt, skip_special_tokens=True
+                )
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            return [
+                self.tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
+        else:
+            # for the case of single str prompt
+            return [request.prompt]
```
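The streaming path above serializes each `CompletionStreamResponse` as a server-sent event (`data: {...}\n\n`) and terminates the stream with `data: [DONE]`. For illustration only, here is a minimal client-side sketch that consumes such a stream; the host, port, and model name are placeholders and are not part of this diff.

```python
import json

import requests  # any HTTP client with streaming support works; shown for illustration

# Placeholder endpoint, port, and model name (assumptions, not taken from this diff).
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={"model": "my-model", "prompt": "Hello", "max_tokens": 16, "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip blank keep-alive lines between events
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break  # end-of-stream sentinel emitted by _generate_completion_stream
    chunk = json.loads(payload)
    if chunk["choices"]:  # the usage-only chunk has an empty choices list
        print(chunk["choices"][0]["text"], end="", flush=True)
```

Note that when `stream_options.include_usage` is set, the last data chunk before `[DONE]` carries only the usage object with an empty `choices` list, which is why the sketch guards on `chunk["choices"]` before indexing.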
```diff
--- /dev/null
+++ sglang/srt/entrypoints/openai/serving_embedding.py
@@ -0,0 +1,169 @@
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse
+
+from sglang.srt.conversation import generate_embedding_convs
+from sglang.srt.entrypoints.openai.protocol import (
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    ErrorResponse,
+    MultimodalEmbeddingInput,
+    UsageInfo,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+from sglang.srt.managers.template_manager import TemplateManager
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+
+class OpenAIServingEmbedding(OpenAIServingBase):
+    """Handler for v1/embeddings requests"""
+
+    def __init__(
+        self,
+        tokenizer_manager: TokenizerManager,
+        template_manager: TemplateManager,
+    ):
+        super().__init__(tokenizer_manager)
+        self.template_manager = template_manager
+
+    def _request_id_prefix(self) -> str:
+        return "embd-"
+
+    def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:
+        """Validate that the input is not empty or whitespace only."""
+        if not (input := request.input):
+            return "Input cannot be empty"
+
+        # Handle single string
+        if isinstance(input, str):
+            if not input.strip():
+                return "Input cannot be empty or whitespace only"
+            return None
+
+        # Handle list inputs
+        if isinstance(input, list):
+            if len(input) == 0:
+                return "Input cannot be empty"
+
+            # Check first element to determine type
+            first_item = input[0]
+
+            if isinstance(first_item, str):
+                # List of strings
+                for i, item in enumerate(input):
+                    if not isinstance(item, str):
+                        return f"All items in input list must be strings"
+                    if not item.strip():
+                        return f"Input at index {i} cannot be empty or whitespace only"
+            elif isinstance(first_item, int):
+                # List of integers (token IDs)
+                for i, item in enumerate(input):
+                    if not isinstance(item, int):
+                        return f"All items in input list must be integers"
+                    if item < 0:
+                        return f"Token ID at index {i} must be non-negative"
+        return None
+
+    def _convert_to_internal_request(
+        self,
+        request: EmbeddingRequest,
+    ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
+        """Convert OpenAI embedding request to internal format"""
+        prompt = request.input
+
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                prompt_kwargs = {"text": prompt}
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                if self.template_manager.chat_template_name is not None:
+                    convs = generate_embedding_convs(
+                        texts, images, self.template_manager.chat_template_name
+                    )
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
+            else:
+                # List of integers (token IDs) or empty list
+                prompt_kwargs = {"input_ids": prompt}
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
+
+        adapted_request = EmbeddingReqInput(
+            **prompt_kwargs,
+        )
+
+        return adapted_request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: EmbeddingReqInput,
+        request: EmbeddingRequest,
+        raw_request: Request,
+    ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
+        """Handle the embedding request"""
+        try:
+            ret = await self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ).__anext__()
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        if not isinstance(ret, list):
+            ret = [ret]
+
+        response = self._build_embedding_response(ret)
+        return response
+
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
+        """Build the embedding response"""
+        embedding_objects = []
+        prompt_tokens = 0
+
+        for idx, ret_item in enumerate(ret):
+            embedding_objects.append(
+                EmbeddingObject(
+                    embedding=ret_item["embedding"],
+                    index=idx,
+                )
+            )
+            # Handle missing prompt_tokens gracefully
+            meta_info = ret_item.get("meta_info", {})
+            prompt_tokens += meta_info.get("prompt_tokens", 0)
+
+        return EmbeddingResponse(
+            data=embedding_objects,
+            model=self.tokenizer_manager.model_path,
+            usage=UsageInfo(
+                prompt_tokens=prompt_tokens,
+                total_tokens=prompt_tokens,
+            ),
+        )
```
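`OpenAIServingEmbedding` above accepts a single string, a list of strings, a list of token ids, or multimodal inputs, and returns one `EmbeddingObject` per input together with prompt-token usage. A minimal client-side sketch of that request/response shape; the host, port, and model name are placeholders, not part of this diff.

```python
import requests  # illustration only

resp = requests.post(
    "http://localhost:30000/v1/embeddings",  # placeholder host/port
    json={"model": "my-embedding-model", "input": ["first text", "second text"]},
)
body = resp.json()
# One vector per input, returned in index order, plus prompt/total token usage.
vectors = [item["embedding"] for item in body["data"]]
print(len(vectors), body["usage"]["prompt_tokens"], body["usage"]["total_tokens"])
```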
```diff
--- /dev/null
+++ sglang/srt/entrypoints/openai/serving_rerank.py
@@ -0,0 +1,102 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    RerankResponse,
+    V1RerankReqInput,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingRerank(OpenAIServingBase):
+    """Handler for /v1/rerank requests"""
+
+    # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "rerank-"
+
+    def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
+        """Validate rerank request format and content"""
+        if not request.query:
+            return "Query cannot be empty"
+
+        if isinstance(request.query, str):
+            if not request.query.strip():
+                return "Query cannot be empty or whitespace only"
+
+        if not request.documents:
+            return "Documents cannot be empty"
+
+        for doc in request.documents:
+            if not doc:
+                return "Each document must be a non-empty string"
+            if isinstance(doc, str) and not doc.strip():
+                return "Each document cannot be empty or whitespace only"
+
+        return None
+
+    def _convert_to_internal_request(
+        self, request: V1RerankReqInput
+    ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
+        """Convert OpenAI rerank request to internal embedding format"""
+        # Create pairs of [query, document] for each document
+        pairs = []
+        for doc in request.documents:
+            pairs.append([request.query, doc])
+
+        adapted_request = EmbeddingReqInput(
+            text=pairs,
+            is_cross_encoder_request=True,
+        )
+
+        return adapted_request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: EmbeddingReqInput,
+        request: V1RerankReqInput,
+        raw_request: Request,
+    ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
+        """Handle the rerank request"""
+        try:
+            ret = await self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ).__anext__()
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        if not isinstance(ret, list):
+            ret = [ret]
+
+        responses = self._build_rerank_response(ret, request)
+        return responses
+
+    def _build_rerank_response(
+        self, ret: List[Dict[str, Any]], request: V1RerankReqInput
+    ) -> List[RerankResponse]:
+        """Build the rerank response from generation results"""
+        responses = []
+        for idx, ret_item in enumerate(ret):
+            responses.append(
+                RerankResponse(
+                    score=ret_item["embedding"],
+                    document=request.documents[idx],
+                    index=idx,
+                    meta_info=ret_item["meta_info"],
+                )
+            )
+
+        # Sort by score in descending order (highest relevance first)
+        responses.sort(key=lambda x: x.score, reverse=True)
+
+        return responses
```
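`OpenAIServingRerank` reuses the embedding path: each `[query, document]` pair is scored as a cross-encoder request, and the handler returns one result per document sorted by descending score. A rough usage sketch follows; the host and port are assumptions, not part of this diff.

```python
import requests  # illustration only

resp = requests.post(
    "http://localhost:30000/v1/rerank",  # placeholder host/port
    json={
        "query": "what is the capital of France?",
        "documents": ["Berlin is a city in Germany.", "Paris is the capital of France."],
    },
)
# One entry per document, already sorted with the most relevant document first.
for item in resp.json():
    print(f'{item["score"]:.4f}  {item["document"]}')
```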