sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_completions.py (new file)
@@ -0,0 +1,424 @@
+ import logging
+ import time
+ from typing import Any, AsyncGenerator, Dict, List, Union
+
+ from fastapi import Request
+ from fastapi.responses import ORJSONResponse, StreamingResponse
+
+ from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
+ from sglang.srt.entrypoints.openai.protocol import (
+     CompletionRequest,
+     CompletionResponse,
+     CompletionResponseChoice,
+     CompletionResponseStreamChoice,
+     CompletionStreamResponse,
+     ErrorResponse,
+ )
+ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+ from sglang.srt.entrypoints.openai.utils import (
+     process_hidden_states_from_ret,
+     to_openai_style_logprobs,
+ )
+ from sglang.srt.managers.io_struct import GenerateReqInput
+ from sglang.srt.managers.template_manager import TemplateManager
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+ logger = logging.getLogger(__name__)
+
+
+ class OpenAIServingCompletion(OpenAIServingBase):
+     """Handler for /v1/completion requests"""
+
+     def __init__(
+         self,
+         tokenizer_manager: TokenizerManager,
+         template_manager: TemplateManager,
+     ):
+         super().__init__(tokenizer_manager)
+         self.template_manager = template_manager
+
+     def _request_id_prefix(self) -> str:
+         return "cmpl-"
+
+     def _convert_to_internal_request(
+         self,
+         request: CompletionRequest,
+     ) -> tuple[GenerateReqInput, CompletionRequest]:
+         """Convert OpenAI completion request to internal format"""
+         # NOTE: with openai API, the prompt's logprobs are always not computed
+         if request.echo and request.logprobs:
+             logger.warning(
+                 "Echo is not compatible with logprobs. "
+                 "To compute logprobs of input prompt, please use the native /generate API."
+             )
+         # Process prompt
+         prompt = request.prompt
+         if self.template_manager.completion_template_name is not None:
+             prompt = generate_completion_prompt_from_request(request)
+
+         # Set logprob start length based on echo and logprobs
+         if request.echo and request.logprobs:
+             logprob_start_len = 0
+         else:
+             logprob_start_len = -1
+
+         # Build sampling parameters
+         sampling_params = self._build_sampling_params(request)
+
+         # Determine prompt format
+         if isinstance(prompt, str) or (
+             isinstance(prompt, list) and isinstance(prompt[0], str)
+         ):
+             prompt_kwargs = {"text": prompt}
+         else:
+             prompt_kwargs = {"input_ids": prompt}
+
+         adapted_request = GenerateReqInput(
+             **prompt_kwargs,
+             sampling_params=sampling_params,
+             return_logprob=request.logprobs is not None,
+             top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+             logprob_start_len=logprob_start_len,
+             return_text_in_logprobs=True,
+             stream=request.stream,
+             lora_path=request.lora_path,
+             bootstrap_host=request.bootstrap_host,
+             bootstrap_port=request.bootstrap_port,
+             bootstrap_room=request.bootstrap_room,
+             return_hidden_states=request.return_hidden_states,
+         )
+
+         return adapted_request, request
+
+     def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
+         """Build sampling parameters for the request"""
+         # Start with common parameters
+         sampling_params = {
+             "temperature": request.temperature,
+             "max_new_tokens": request.max_tokens,
+             "min_new_tokens": request.min_tokens,
+             "stop": request.stop,
+             "stop_token_ids": request.stop_token_ids,
+             "top_p": request.top_p,
+             "top_k": request.top_k,
+             "min_p": request.min_p,
+             "presence_penalty": request.presence_penalty,
+             "frequency_penalty": request.frequency_penalty,
+             "repetition_penalty": request.repetition_penalty,
+             "regex": request.regex,
+             "json_schema": request.json_schema,
+             "ebnf": request.ebnf,
+             "n": request.n,
+             "no_stop_trim": request.no_stop_trim,
+             "ignore_eos": request.ignore_eos,
+             "skip_special_tokens": request.skip_special_tokens,
+             "logit_bias": request.logit_bias,
+         }
+
+         return sampling_params
+
+     async def _handle_streaming_request(
+         self,
+         adapted_request: GenerateReqInput,
+         request: CompletionRequest,
+         raw_request: Request,
+     ) -> StreamingResponse:
+         """Handle streaming completion request"""
+         return StreamingResponse(
+             self._generate_completion_stream(adapted_request, request, raw_request),
+             media_type="text/event-stream",
+             background=self.tokenizer_manager.create_abort_task(adapted_request),
+         )
+
+     async def _generate_completion_stream(
+         self,
+         adapted_request: GenerateReqInput,
+         request: CompletionRequest,
+         raw_request: Request,
+     ) -> AsyncGenerator[str, None]:
+         """Generate streaming completion response"""
+         created = int(time.time())
+
+         # State tracking for streaming
+         stream_buffers = {}
+         n_prev_tokens = {}
+
+         # Usage tracking
+         prompt_tokens = {}
+         completion_tokens = {}
+         cached_tokens = {}
+         hidden_states = {}
+
+         try:
+             async for content in self.tokenizer_manager.generate_request(
+                 adapted_request, raw_request
+             ):
+                 index = content.get("index", 0)
+
+                 text = content["text"]
+                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
+                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                 hidden_states[index] = content["meta_info"].get("hidden_states", None)
+
+                 stream_buffer = stream_buffers.get(index, "")
+                 # Handle echo for first chunk
+                 if not stream_buffer:  # The first chunk
+                     if request.echo:
+                         echo_text = self._get_echo_text(request, index)
+                         text = echo_text + text
+
+                 # Handle logprobs
+                 logprobs = None
+                 if request.logprobs is not None:
+                     # The first chunk and echo is enabled.
+                     if not stream_buffer and request.echo:
+                         input_token_logprobs = content["meta_info"][
+                             "input_token_logprobs"
+                         ]
+                         input_top_logprobs = content["meta_info"]["input_top_logprobs"]
+                     else:
+                         input_token_logprobs = None
+                         input_top_logprobs = None
+
+                     n_prev_token = n_prev_tokens.get(index, 0)
+                     logprobs = to_openai_style_logprobs(
+                         input_token_logprobs=input_token_logprobs,
+                         input_top_logprobs=input_top_logprobs,
+                         output_token_logprobs=content["meta_info"][
+                             "output_token_logprobs"
+                         ][n_prev_token:],
+                         output_top_logprobs=content["meta_info"]["output_top_logprobs"][
+                             n_prev_token:
+                         ],
+                     )
+                     n_prev_tokens[index] = len(
+                         content["meta_info"]["output_token_logprobs"]
+                     )
+
+                 # Generate delta
+                 delta = text[len(stream_buffer) :]
+                 stream_buffers[index] = stream_buffer + delta
+                 finish_reason = content["meta_info"]["finish_reason"]
+
+                 choice_data = CompletionResponseStreamChoice(
+                     index=index,
+                     text=delta,
+                     logprobs=logprobs,
+                     finish_reason=finish_reason["type"] if finish_reason else None,
+                     matched_stop=(
+                         finish_reason["matched"]
+                         if finish_reason and "matched" in finish_reason
+                         else None
+                     ),
+                 )
+                 chunk = CompletionStreamResponse(
+                     id=content["meta_info"]["id"],
+                     created=created,
+                     object="text_completion",
+                     choices=[choice_data],
+                     model=request.model,
+                 )
+
+                 yield f"data: {chunk.model_dump_json()}\n\n"
+
+             if request.return_hidden_states and hidden_states:
+                 for index, choice_hidden_states in hidden_states.items():
+                     if choice_hidden_states:
+                         last_token_hidden_states = (
+                             choice_hidden_states[-1]
+                             if len(choice_hidden_states) > 1
+                             else []
+                         )
+                         hidden_states_chunk = CompletionStreamResponse(
+                             id=content["meta_info"]["id"],
+                             created=created,
+                             object="text_completion",
+                             choices=[
+                                 CompletionResponseStreamChoice(
+                                     index=index,
+                                     text="",
+                                     hidden_states=last_token_hidden_states,
+                                     finish_reason=None,
+                                 )
+                             ],
+                             model=request.model,
+                         )
+                         yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
+             # Handle final usage chunk
+             if request.stream_options and request.stream_options.include_usage:
+                 usage = UsageProcessor.calculate_streaming_usage(
+                     prompt_tokens,
+                     completion_tokens,
+                     cached_tokens,
+                     n_choices=request.n,
+                     enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
+                 )
+                 final_usage_chunk = CompletionStreamResponse(
+                     id=content["meta_info"]["id"],
+                     created=created,
+                     choices=[],
+                     model=request.model,
+                     usage=usage,
+                 )
+                 final_usage_data = final_usage_chunk.model_dump_json(exclude_none=True)
+                 yield f"data: {final_usage_data}\n\n"
+
+         except Exception as e:
+             error = self.create_streaming_error_response(str(e))
+             yield f"data: {error}\n\n"
+
+         yield "data: [DONE]\n\n"
+
+     async def _handle_non_streaming_request(
+         self,
+         adapted_request: GenerateReqInput,
+         request: CompletionRequest,
+         raw_request: Request,
+     ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]:
+         """Handle non-streaming completion request"""
+         try:
+             generator = self.tokenizer_manager.generate_request(
+                 adapted_request, raw_request
+             )
+             ret = await generator.__anext__()
+         except ValueError as e:
+             return self.create_error_response(str(e))
+
+         if not isinstance(ret, list):
+             ret = [ret]
+
+         response = self._build_completion_response(
+             request,
+             ret,
+             int(time.time()),
+         )
+
+         return response
+
+     def _build_completion_response(
+         self,
+         request: CompletionRequest,
+         ret: List[Dict[str, Any]],
+         created: int,
+     ) -> CompletionResponse:
+         """Build completion response from generation results"""
+         choices = []
+         echo = False
+
+         # Prepare echo prompts if needed
+         echo_prompts = []
+         if request.echo:
+             echo_prompts = self._prepare_echo_prompts(request)
+             echo = True
+
+         for idx, ret_item in enumerate(ret):
+             text = ret_item["text"]
+
+             # Handle echo
+             if echo:
+                 prompt_index = idx // request.n
+                 text = echo_prompts[prompt_index] + text
+
+             # Handle logprobs
+             logprobs = None
+             if request.logprobs is not None:
+                 if echo:
+                     input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                     input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+                 else:
+                     input_token_logprobs = None
+                     input_top_logprobs = None
+
+                 logprobs = to_openai_style_logprobs(
+                     input_token_logprobs=input_token_logprobs,
+                     input_top_logprobs=input_top_logprobs,
+                     output_token_logprobs=ret_item["meta_info"][
+                         "output_token_logprobs"
+                     ],
+                     output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+                 )
+
+             # Handle hidden states
+             hidden_states = process_hidden_states_from_ret(ret_item, request)
+
+             finish_reason = ret_item["meta_info"]["finish_reason"]
+
+             choice_data = CompletionResponseChoice(
+                 index=idx,
+                 text=text,
+                 logprobs=logprobs,
+                 finish_reason=finish_reason["type"] if finish_reason else None,
+                 matched_stop=(
+                     finish_reason["matched"]
+                     if finish_reason and "matched" in finish_reason
+                     else None
+                 ),
+                 hidden_states=hidden_states,
+             )
+             choices.append(choice_data)
+
+         # Calculate usage
+         cache_report = self.tokenizer_manager.server_args.enable_cache_report
+         usage = UsageProcessor.calculate_response_usage(
+             ret, n_choices=request.n, enable_cache_report=cache_report
+         )
+
+         return CompletionResponse(
+             id=ret[0]["meta_info"]["id"],
+             model=request.model,
+             created=created,
+             choices=choices,
+             usage=usage,
+         )
+
+     def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
+         """Get echo text for streaming response"""
+         if isinstance(request.prompt, str):
+             # for the case of single str prompts
+             return request.prompt
+         elif isinstance(request.prompt, list):
+             if isinstance(request.prompt[0], str):
+                 # for the case of multiple str prompts
+                 return request.prompt[index // request.n]
+             elif isinstance(request.prompt[0], int):
+                 # for the case of single token ids prompt
+                 return self.tokenizer_manager.tokenizer.decode(
+                     request.prompt, skip_special_tokens=True
+                 )
+             elif isinstance(request.prompt[0], list) and isinstance(
+                 request.prompt[0][0], int
+             ):
+                 # for the case of multiple token ids prompts
+                 return self.tokenizer_manager.tokenizer.decode(
+                     request.prompt[index // request.n],
+                     skip_special_tokens=True,
+                 )
+         return ""
+
+     def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
+         """Prepare echo prompts for non-streaming response"""
+         # TODO: handle the case prompt is token ids
+         if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+             # for the case of multiple str prompts
+             return request.prompt
+         elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+             # for the case of multiple token ids prompts
+             return [
+                 self.tokenizer_manager.tokenizer.decode(
+                     prompt, skip_special_tokens=True
+                 )
+                 for prompt in request.prompt
+             ]
+         elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+             # for the case of single token ids prompt
+             return [
+                 self.tokenizer_manager.tokenizer.decode(
+                     request.prompt, skip_special_tokens=True
+                 )
+             ]
+         else:
+             # for the case of single str prompt
+             return [request.prompt]
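For reference, a minimal client-side sketch (not part of the diff) exercising the fields that OpenAIServingCompletion maps into GenerateReqInput: echo, logprobs, n, and stream. The base URL, API key, and model name below are placeholder assumptions; any OpenAI-compatible client pointed at an sglang server should behave similarly.

# Hypothetical usage sketch; base URL, API key, and model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

# Non-streaming request with echo + logprobs. As noted in the handler, prompt
# logprobs are not computed by the OpenAI-compatible path; it only adjusts
# logprob_start_len when echo and logprobs are combined.
resp = client.completions.create(
    model="placeholder-model",
    prompt="The capital of France is",
    max_tokens=16,
    echo=True,
    logprobs=2,
    n=1,
)
print(resp.choices[0].text)

# Streaming request; each chunk follows the CompletionStreamResponse schema above.
for chunk in client.completions.create(
    model="placeholder-model",
    prompt="Count to five:",
    max_tokens=16,
    stream=True,
):
    print(chunk.choices[0].text, end="")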
sglang/srt/entrypoints/openai/serving_embedding.py (new file)
@@ -0,0 +1,169 @@
+ from typing import Any, Dict, List, Optional, Union
+
+ from fastapi import Request
+ from fastapi.responses import ORJSONResponse
+
+ from sglang.srt.conversation import generate_embedding_convs
+ from sglang.srt.entrypoints.openai.protocol import (
+     EmbeddingObject,
+     EmbeddingRequest,
+     EmbeddingResponse,
+     ErrorResponse,
+     MultimodalEmbeddingInput,
+     UsageInfo,
+ )
+ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+ from sglang.srt.managers.io_struct import EmbeddingReqInput
+ from sglang.srt.managers.template_manager import TemplateManager
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+
+ class OpenAIServingEmbedding(OpenAIServingBase):
+     """Handler for v1/embeddings requests"""
+
+     def __init__(
+         self,
+         tokenizer_manager: TokenizerManager,
+         template_manager: TemplateManager,
+     ):
+         super().__init__(tokenizer_manager)
+         self.template_manager = template_manager
+
+     def _request_id_prefix(self) -> str:
+         return "embd-"
+
+     def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:
+         """Validate that the input is not empty or whitespace only."""
+         if not (input := request.input):
+             return "Input cannot be empty"
+
+         # Handle single string
+         if isinstance(input, str):
+             if not input.strip():
+                 return "Input cannot be empty or whitespace only"
+             return None
+
+         # Handle list inputs
+         if isinstance(input, list):
+             if len(input) == 0:
+                 return "Input cannot be empty"
+
+             # Check first element to determine type
+             first_item = input[0]
+
+             if isinstance(first_item, str):
+                 # List of strings
+                 for i, item in enumerate(input):
+                     if not isinstance(item, str):
+                         return f"All items in input list must be strings"
+                     if not item.strip():
+                         return f"Input at index {i} cannot be empty or whitespace only"
+             elif isinstance(first_item, int):
+                 # List of integers (token IDs)
+                 for i, item in enumerate(input):
+                     if not isinstance(item, int):
+                         return f"All items in input list must be integers"
+                     if item < 0:
+                         return f"Token ID at index {i} must be non-negative"
+         return None
+
+     def _convert_to_internal_request(
+         self,
+         request: EmbeddingRequest,
+     ) -> tuple[EmbeddingReqInput, EmbeddingRequest]:
+         """Convert OpenAI embedding request to internal format"""
+         prompt = request.input
+
+         if isinstance(prompt, str):
+             # Single string input
+             prompt_kwargs = {"text": prompt}
+         elif isinstance(prompt, list):
+             if len(prompt) > 0 and isinstance(prompt[0], str):
+                 prompt_kwargs = {"text": prompt}
+             elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                 # Handle multimodal embedding inputs
+                 texts = []
+                 images = []
+                 for item in prompt:
+                     # Use padding for text if None - this could be improved
+                     texts.append(item.text if item.text is not None else "padding")
+                     images.append(item.image if item.image is not None else None)
+
+                 generate_prompts = []
+                 # Check if we have a chat template for multimodal embeddings
+                 if self.template_manager.chat_template_name is not None:
+                     convs = generate_embedding_convs(
+                         texts, images, self.template_manager.chat_template_name
+                     )
+                     for conv in convs:
+                         generate_prompts.append(conv.get_prompt())
+                 else:
+                     generate_prompts = texts
+
+                 if len(generate_prompts) == 1:
+                     prompt_kwargs = {
+                         "text": generate_prompts[0],
+                         "image_data": images[0],
+                     }
+                 else:
+                     prompt_kwargs = {
+                         "text": generate_prompts,
+                         "image_data": images,
+                     }
+             else:
+                 # List of integers (token IDs) or empty list
+                 prompt_kwargs = {"input_ids": prompt}
+         else:
+             # Other types (should not happen but handle gracefully)
+             prompt_kwargs = {"input_ids": prompt}
+
+         adapted_request = EmbeddingReqInput(
+             **prompt_kwargs,
+         )
+
+         return adapted_request, request
+
+     async def _handle_non_streaming_request(
+         self,
+         adapted_request: EmbeddingReqInput,
+         request: EmbeddingRequest,
+         raw_request: Request,
+     ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]:
+         """Handle the embedding request"""
+         try:
+             ret = await self.tokenizer_manager.generate_request(
+                 adapted_request, raw_request
+             ).__anext__()
+         except ValueError as e:
+             return self.create_error_response(str(e))
+
+         if not isinstance(ret, list):
+             ret = [ret]
+
+         response = self._build_embedding_response(ret)
+         return response
+
+     def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
+         """Build the embedding response"""
+         embedding_objects = []
+         prompt_tokens = 0
+
+         for idx, ret_item in enumerate(ret):
+             embedding_objects.append(
+                 EmbeddingObject(
+                     embedding=ret_item["embedding"],
+                     index=idx,
+                 )
+             )
+             # Handle missing prompt_tokens gracefully
+             meta_info = ret_item.get("meta_info", {})
+             prompt_tokens += meta_info.get("prompt_tokens", 0)
+
+         return EmbeddingResponse(
+             data=embedding_objects,
+             model=self.tokenizer_manager.model_path,
+             usage=UsageInfo(
+                 prompt_tokens=prompt_tokens,
+                 total_tokens=prompt_tokens,
+             ),
+         )
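For reference, a minimal sketch (not part of the diff) of the input forms accepted by OpenAIServingEmbedding._validate_request: a single string, a list of strings, or a list of token IDs. The base URL and model name are assumptions; the token IDs shown are arbitrary placeholders.

# Hypothetical usage sketch; server address and model name are assumptions.
import requests

base = "http://localhost:30000"
payloads = [
    {"model": "placeholder-model", "input": "hello world"},        # single string
    {"model": "placeholder-model", "input": ["hello", "world"]},   # list of strings
    {"model": "placeholder-model", "input": [101, 2023, 2003]},    # list of token IDs
]
for payload in payloads:
    r = requests.post(f"{base}/v1/embeddings", json=payload)
    data = r.json()
    # Each entry in data["data"] follows EmbeddingObject (embedding + index);
    # usage mirrors the UsageInfo built in _build_embedding_response.
    print(len(data["data"]), data["usage"]["prompt_tokens"])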
sglang/srt/entrypoints/openai/serving_rerank.py (new file)
@@ -0,0 +1,102 @@
+ import logging
+ from typing import Any, Dict, List, Optional, Union
+
+ from fastapi import Request
+ from fastapi.responses import ORJSONResponse
+
+ from sglang.srt.entrypoints.openai.protocol import (
+     ErrorResponse,
+     RerankResponse,
+     V1RerankReqInput,
+ )
+ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+ from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+ logger = logging.getLogger(__name__)
+
+
+ class OpenAIServingRerank(OpenAIServingBase):
+     """Handler for /v1/rerank requests"""
+
+     # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+     # to another module in the future.
+
+     def _request_id_prefix(self) -> str:
+         return "rerank-"
+
+     def _validate_request(self, request: V1RerankReqInput) -> Optional[str]:
+         """Validate rerank request format and content"""
+         if not request.query:
+             return "Query cannot be empty"
+
+         if isinstance(request.query, str):
+             if not request.query.strip():
+                 return "Query cannot be empty or whitespace only"
+
+         if not request.documents:
+             return "Documents cannot be empty"
+
+         for doc in request.documents:
+             if not doc:
+                 return "Each document must be a non-empty string"
+             if isinstance(doc, str) and not doc.strip():
+                 return "Each document cannot be empty or whitespace only"
+
+         return None
+
+     def _convert_to_internal_request(
+         self, request: V1RerankReqInput
+     ) -> tuple[EmbeddingReqInput, V1RerankReqInput]:
+         """Convert OpenAI rerank request to internal embedding format"""
+         # Create pairs of [query, document] for each document
+         pairs = []
+         for doc in request.documents:
+             pairs.append([request.query, doc])
+
+         adapted_request = EmbeddingReqInput(
+             text=pairs,
+             is_cross_encoder_request=True,
+         )
+
+         return adapted_request, request
+
+     async def _handle_non_streaming_request(
+         self,
+         adapted_request: EmbeddingReqInput,
+         request: V1RerankReqInput,
+         raw_request: Request,
+     ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]:
+         """Handle the rerank request"""
+         try:
+             ret = await self.tokenizer_manager.generate_request(
+                 adapted_request, raw_request
+             ).__anext__()
+
+         except ValueError as e:
+             return self.create_error_response(str(e))
+
+         if not isinstance(ret, list):
+             ret = [ret]
+
+         responses = self._build_rerank_response(ret, request)
+         return responses
+
+     def _build_rerank_response(
+         self, ret: List[Dict[str, Any]], request: V1RerankReqInput
+     ) -> List[RerankResponse]:
+         """Build the rerank response from generation results"""
+         responses = []
+         for idx, ret_item in enumerate(ret):
+             responses.append(
+                 RerankResponse(
+                     score=ret_item["embedding"],
+                     document=request.documents[idx],
+                     index=idx,
+                     meta_info=ret_item["meta_info"],
+                 )
+             )
+
+         # Sort by score in descending order (highest relevance first)
+         responses.sort(key=lambda x: x.score, reverse=True)
+
+         return responses
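For reference, a minimal sketch (not part of the diff) of a /v1/rerank call built from the V1RerankReqInput fields used above (query, documents). The server address is an assumption, the served model must support cross-encoder scoring, and the assumption that the response body is a plain JSON array of RerankResponse objects follows from the handler returning a list.

# Hypothetical usage sketch; server address is an assumption.
import requests

payload = {
    "query": "what is the capital of France?",
    "documents": [
        "Paris is the capital of France.",
        "The mitochondria is the powerhouse of the cell.",
    ],
}
r = requests.post("http://localhost:30000/v1/rerank", json=payload)
# Results are sorted by score in descending order (see _build_rerank_response).
for item in r.json():
    print(item["index"], round(item["score"], 4), item["document"])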