sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/vila.py (new file)
@@ -0,0 +1,305 @@
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+
+##### BEGIN COPY configuration.py #####
+
+
+class VILAConfig(PretrainedConfig):
+    # Class attributes.
+    model_type: str = "vila"
+    sub_configs: Dict[str, PretrainedConfig] = {
+        "text_config": Qwen2Config(),
+        "vision_config": SiglipVisionConfig(),
+    }
+    _auto_class: Optional[str] = "AutoConfig"
+
+    # Configuration for sub-modules.
+    text_config: Qwen2Config = Qwen2Config()
+    vision_config: SiglipVisionConfig = SiglipVisionConfig()
+
+    # Model configuration.
+    hidden_size: int
+    image_token_id: int
+    mm_hidden_size: int
+    mm_projector_type: str
+    mm_vision_select_feature: str
+    mm_vision_select_layer: int
+    video_token_id: int
+
+    def __init__(
+        self,
+        text_config: Optional[Dict[str, Any]] = None,
+        vision_config: Optional[Dict[str, Any]] = None,
+        *,
+        hidden_size: int = 1536,
+        image_token_id: int = 151649,
+        mm_hidden_size: int = 1152,
+        mm_projector_type: str = "mlp_downsample_3x3_fix",
+        mm_vision_select_feature: str = "cls_patch",
+        mm_vision_select_layer: int = -2,
+        video_token_id: int = 151650,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config
+            else SiglipVisionConfig()
+        )
+
+        self.hidden_size = hidden_size
+        self.image_token_id = image_token_id
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.video_token_id = video_token_id
+
+
+##### END COPY configuration.py #####
+
+##### BEGIN COPY modeling_vila.py #####
+
+
+class DownSample3x3BlockFix(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+        """
+
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = int(sequence_length**0.5)
+        if feat_size**2 != sequence_length:
+            raise ValueError(
+                f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
+            )
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class MultimodalProjector(nn.Module):
+    layers: nn.Sequential
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_projector_type: {config.mm_projector_type}"
+            )
+
+        self.layers.type(config.torch_dtype)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, hidden_size).
+        """
+
+        return self.layers(x.to(device=self.device, dtype=self.dtype))
+
+
+##### END COPY modeling_vila.py #####
+
+
+class VILAForConditionalGeneration(nn.Module):
+    config: VILAConfig
+    quant_config: Optional[QuantizationConfig]
+
+    logits_processor: LogitsProcessor
+    pooler: Pooler
+
+    llm: Qwen2ForCausalLM
+    mm_projector: MultimodalProjector
+    vision_tower: SiglipVisionModel
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+        self.mm_projector = MultimodalProjector(config)
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.config.torch_dtype
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        return cast(LogitsProcessorOutput, output)
+
+    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
+        pixel_values = cast(Tensor, mm_input[0].pixel_values)
+
+        ##### BEGIN COPY modeling_vila.py #####
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+
+        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
+            vision_tower_output
+        )
+
+        image_embedding: Tensor = self.mm_projector.__call__(
+            mm_projector_input.to(
+                device=self.mm_projector.device, dtype=self.mm_projector.dtype
+            )
+        )
+
+        ##### END COPY modeling_vila.py #####
+
+        return image_embedding
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        image_inputs: MultimodalInputs,
+    ) -> List[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
+            token_ids=[self.config.image_token_id],
+        )
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+    ##### BEGIN COPY modeling_vila.py #####
+
+    def _vision_tower_output_to_mm_projector_input(
+        self,
+        vision_tower_output: BaseModelOutputWithPooling,
+    ) -> Tensor:
+        assert vision_tower_output.hidden_states is not None
+
+        selected_layer_hidden_states = vision_tower_output.hidden_states[
+            self.config.mm_vision_select_layer
+        ]
+
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
+            )
+
+    ##### END COPY modeling_vila.py #####
+
+
+EntryClass = [VILAForConditionalGeneration]
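
Note (not part of the diff): the "mlp_downsample_3x3_fix" projector above folds each 3x3 block of SigLIP patch tokens into one token of width 9 * mm_hidden_size before the MLP maps it to hidden_size. A minimal standalone sketch of that shape math, assuming only torch is installed and using the config defaults above (mm_hidden_size=1152, hidden_size=1536) with a hypothetical 14x14 patch grid:

import torch
import torch.nn.functional as F

# Replays the reshaping done by DownSample3x3BlockFix for one example input.
x = torch.randn(1, 196, 1152)            # (batch, 14*14 patches, mm_hidden_size)
b, s, h = x.shape
feat = int(s**0.5)                       # 14
grid = x.reshape(b, feat, feat, h)

pad = (3 - feat % 3) % 3                 # pad 14 -> 15 so it divides by 3
if pad > 0:
    grid = F.pad(grid, (0, 0, 0, pad, 0, pad))
    feat += pad

grid = grid.reshape(b, feat // 3, 3, feat // 3, 3, h)
grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
out = grid.reshape(b, -1, 9 * h)
print(out.shape)                         # torch.Size([1, 25, 10368]); the MLP then maps 9*1152 -> 1536

The projected features are what get_image_feature returns, and general_mm_embed_routine splices them in at the image token positions.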
sglang/srt/openai_api/adapter.py
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
     register_conv_template,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    V1RerankReqInput,
+)
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
     FunctionResponse,
     LogProbs,
     MultimodalEmbeddingInput,
+    RerankResponse,
     ScoringRequest,
     ScoringResponse,
     ToolCall,
@@ -542,6 +547,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []

     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -581,6 +587,7 @@
                 "no_stop_trim": request.no_stop_trim,
                 "ignore_eos": request.ignore_eos,
                 "skip_special_tokens": request.skip_special_tokens,
+                "logit_bias": request.logit_bias,
             }
         )
         return_logprobs.append(request.logprobs is not None)
@@ -588,6 +595,7 @@
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)

     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -599,6 +607,7 @@
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -615,6 +624,7 @@
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
@@ -683,6 +693,16 @@ def v1_generate_response(
         else:
             logprobs = None

+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         if to_file:
@@ -698,6 +718,8 @@
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
@@ -709,6 +731,7 @@
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)
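
Note (not part of the diff): return_hidden_states and logit_bias are plain request fields on the OpenAI-compatible endpoints; only the field names and the "hidden_states" key on the choice come from the code above. A hedged client sketch against a locally running server (URL, port, and model name are placeholders):

import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",   # assumed local sglang server
    json={
        "model": "my-model",                    # placeholder served model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logit_bias": {"1234": -100},           # token id -> bias, forwarded into sampling_params
        "return_hidden_states": True,
    },
)
choice = resp.json()["choices"][0]
# Per v1_generate_response above, the choice then carries the last decoded
# token's hidden state under "hidden_states".
print(choice.get("hidden_states"))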
@@ -777,6 +800,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}

             try:
                 async for content in tokenizer_manager.generate_request(
@@ -791,6 +815,9 @@
                     prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                     completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)

                     if not stream_buffer:  # The first chunk
                         if request.echo:
@@ -873,6 +900,27 @@
                     n_prev_tokens[index] = n_prev_token

                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    text="",
+                                    index=index,
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
                     total_prompt_tokens = sum(
                         tokens
@@ -973,6 +1021,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []

     # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -1176,6 +1225,7 @@
             "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
             "skip_special_tokens": request.skip_special_tokens,
+            "logit_bias": request.logit_bias,
         }

         if request.response_format and request.response_format.type == "json_schema":
@@ -1215,6 +1265,7 @@
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1233,6 +1284,7 @@
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1259,6 +1311,7 @@
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )

     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1319,6 +1372,20 @@
         else:
             choice_logprobs = None

+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         tool_calls = None
@@ -1391,6 +1458,8 @@
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
@@ -1407,6 +1476,7 @@
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )

         choices.append(choice_data)
@@ -1479,19 +1549,23 @@ async def v1_chat_completions(
         reasoning_parser_dict = {}

         async def generate_stream_resp():
-            tool_call_first = True
+            tool_index_previous = -1
             is_firsts = {}
             stream_buffers = {}
             n_prev_tokens = {}
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)

                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1613,6 +1687,7 @@ async def v1_chat_completions(
                     if (delta and len(delta) == 0) or not delta:
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                         continue

                     if request.tool_choice != "none" and request.tools:
@@ -1645,6 +1720,7 @@

                         # 2) if we found calls, we output them as separate chunk(s)
                         for call_item in calls:
+                            tool_index_current = call_item.tool_index
                             # transform call_item -> FunctionResponse + ToolCall
                             if finish_reason_type == "stop":
                                 latest_delta_len = 0
@@ -1671,7 +1747,7 @@
                             tool_call = ToolCall(
                                 id=(
                                     f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
-                                    if tool_call_first
+                                    if tool_index_previous != tool_index_current
                                     else None
                                 ),
                                 index=call_item.tool_index,
@@ -1680,7 +1756,7 @@
                                     arguments=call_item.parameters,
                                 ),
                             )
-                            tool_call_first = False
+                            tool_index_previous = tool_index_current
                             choice_data = ChatCompletionResponseStreamChoice(
                                 index=index,
                                 delta=DeltaMessage(tool_calls=[tool_call]),
@@ -1701,6 +1777,7 @@

                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token

                     else:
                         # No tool calls => just treat this as normal text
@@ -1733,6 +1810,7 @@
                         yield f"data: {chunk.model_dump_json()}\n\n"
                         stream_buffers[index] = new_stream_buffer
                         is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                 if finish_reason_type == "stop" and request.tool_choice != "none":
                     parser = FunctionCallParser(
                         tools=request.tools,
@@ -1768,6 +1846,28 @@

                 else:
                     usage = None
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
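
Note (not part of the diff): in streaming mode the hidden states arrive as one extra SSE chunk per choice after the text chunks, with the vector under delta.hidden_states, as emitted by the hunk above. A hedged sketch of consuming that stream (URL, port, and model name are placeholders):

import json
import requests

with requests.post(
    "http://localhost:30000/v1/chat/completions",   # assumed local sglang server
    json={
        "model": "my-model",                         # placeholder served model name
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
        "return_hidden_states": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if not chunk.get("choices"):
            continue  # e.g. a usage-only chunk
        delta = chunk["choices"][0].get("delta") or {}
        if delta.get("hidden_states") is not None:
            print("last-token hidden state length:", len(delta["hidden_states"]))
        elif delta.get("content"):
            print(delta["content"], end="")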
@@ -1925,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     return response


+def v1_rerank_request(obj: V1RerankReqInput):
+    if obj.query is None:
+        raise ValueError("query is required")
+    if obj.documents is None or len(obj.documents) == 0:
+        raise ValueError("documents is required")
+
+    pairs = []
+    for doc in obj.documents:
+        pairs.append([obj.query, doc])
+
+    adapted_request = EmbeddingReqInput(
+        text=pairs,
+        is_cross_encoder_request=True,
+    )
+
+    return adapted_request
+
+
+def v1_rerank_response(ret, obj: V1RerankReqInput):
+
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            RerankResponse(
+                score=ret[idx]["embedding"],
+                document=obj.documents[idx],
+                index=idx,
+                meta_info=ret[idx]["meta_info"],
+            )
+        )
+
+    response.sort(key=lambda x: x.score, reverse=True)
+
+    return response
+
+
+async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
+    adapted_request = v1_rerank_request(obj)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_rerank_response(
+        ret,
+        obj,
+    )
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
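
Note (not part of the diff): v1_rerank above scores each (query, document) pair through the embedding path with is_cross_encoder_request=True and returns a list sorted by score, highest first. The HTTP route itself is registered in sglang/srt/entrypoints/http_server.py (changed in this release but not shown in this section); /v1/rerank below is an assumption, as are host, port, and the example documents:

import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",      # assumed route for the v1_rerank handler
    json={
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
)
# v1_rerank_response returns one entry per document with score, document,
# index, and meta_info, sorted by score in descending order.
for item in resp.json():
    print(item["index"], item["score"], item["document"])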