sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/vila.py
ADDED
@@ -0,0 +1,305 @@
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+
+##### BEGIN COPY configuration.py #####
+
+
+class VILAConfig(PretrainedConfig):
+    # Class attributes.
+    model_type: str = "vila"
+    sub_configs: Dict[str, PretrainedConfig] = {
+        "text_config": Qwen2Config(),
+        "vision_config": SiglipVisionConfig(),
+    }
+    _auto_class: Optional[str] = "AutoConfig"
+
+    # Configuration for sub-modules.
+    text_config: Qwen2Config = Qwen2Config()
+    vision_config: SiglipVisionConfig = SiglipVisionConfig()
+
+    # Model configuration.
+    hidden_size: int
+    image_token_id: int
+    mm_hidden_size: int
+    mm_projector_type: str
+    mm_vision_select_feature: str
+    mm_vision_select_layer: int
+    video_token_id: int
+
+    def __init__(
+        self,
+        text_config: Optional[Dict[str, Any]] = None,
+        vision_config: Optional[Dict[str, Any]] = None,
+        *,
+        hidden_size: int = 1536,
+        image_token_id: int = 151649,
+        mm_hidden_size: int = 1152,
+        mm_projector_type: str = "mlp_downsample_3x3_fix",
+        mm_vision_select_feature: str = "cls_patch",
+        mm_vision_select_layer: int = -2,
+        video_token_id: int = 151650,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config
+            else SiglipVisionConfig()
+        )
+
+        self.hidden_size = hidden_size
+        self.image_token_id = image_token_id
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.video_token_id = video_token_id
+
+
+##### END COPY configuration.py #####
+
+##### BEGIN COPY modeling_vila.py #####
+
+
+class DownSample3x3BlockFix(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+        """
+
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = int(sequence_length**0.5)
+        if feat_size**2 != sequence_length:
+            raise ValueError(
+                f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
+            )
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class MultimodalProjector(nn.Module):
+    layers: nn.Sequential
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_projector_type: {config.mm_projector_type}"
+            )
+
+        self.layers.type(config.torch_dtype)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, hidden_size).
+        """
+
+        return self.layers(x.to(device=self.device, dtype=self.dtype))
+
+
+##### END COPY modeling_vila.py #####
+
+
+class VILAForConditionalGeneration(nn.Module):
+    config: VILAConfig
+    quant_config: Optional[QuantizationConfig]
+
+    logits_processor: LogitsProcessor
+    pooler: Pooler
+
+    llm: Qwen2ForCausalLM
+    mm_projector: MultimodalProjector
+    vision_tower: SiglipVisionModel
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+        self.mm_projector = MultimodalProjector(config)
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.config.torch_dtype
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        return cast(LogitsProcessorOutput, output)
+
+    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
+        pixel_values = cast(Tensor, mm_input[0].pixel_values)
+
+        ##### BEGIN COPY modeling_vila.py #####
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+
+        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
+            vision_tower_output
+        )
+
+        image_embedding: Tensor = self.mm_projector.__call__(
+            mm_projector_input.to(
+                device=self.mm_projector.device, dtype=self.mm_projector.dtype
+            )
+        )
+
+        ##### END COPY modeling_vila.py #####
+
+        return image_embedding
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        image_inputs: MultimodalInputs,
+    ) -> List[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
+            token_ids=[self.config.image_token_id],
+        )
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+    ##### BEGIN COPY modeling_vila.py #####
+
+    def _vision_tower_output_to_mm_projector_input(
+        self,
+        vision_tower_output: BaseModelOutputWithPooling,
+    ) -> Tensor:
+        assert vision_tower_output.hidden_states is not None
+
+        selected_layer_hidden_states = vision_tower_output.hidden_states[
+            self.config.mm_vision_select_layer
+        ]
+
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
+            )
+
+    ##### END COPY modeling_vila.py #####
+
+
+EntryClass = [VILAForConditionalGeneration]
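Note: the DownSample3x3BlockFix added above folds each 3x3 neighborhood of vision-tower patch features into one token before the projector MLP. The standalone PyTorch sketch below reproduces that reshape arithmetic to make the shapes concrete; the 27x27 SigLIP patch grid and mm_hidden_size of 1152 are illustrative assumptions, and the snippet is not part of the diff.

```python
import torch
import torch.nn.functional as F

# Illustrative shape check for the 3x3 spatial downsample used by the VILA projector.
# Assumed sizes: a 27x27 grid of SigLIP patch features with mm_hidden_size = 1152.
batch_size, feat_size, hidden_size = 1, 27, 1152
x = torch.randn(batch_size, feat_size * feat_size, hidden_size)

b, s, h = x.shape
side = int(s**0.5)                      # 27
features = x.reshape(b, side, side, h)

pad = (3 - side % 3) % 3                # 0 here; grids not divisible by 3 get padded
if pad:
    features = F.pad(features, (0, 0, 0, pad, 0, pad))
    side += pad

# Group each 3x3 patch neighborhood into a single token of 9*h channels.
features = features.reshape(b, side // 3, 3, side // 3, 3, h)
features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
features = features.reshape(b, -1, 9 * h)

print(features.shape)  # torch.Size([1, 81, 10368])
```

The projector MLP then maps these 9*h-channel tokens down to the language model's hidden_size, so every image contributes roughly one ninth as many tokens as the raw patch grid.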
sglang/srt/openai_api/adapter.py
CHANGED
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
     register_conv_template,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    V1RerankReqInput,
+)
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
     FunctionResponse,
     LogProbs,
     MultimodalEmbeddingInput,
+    RerankResponse,
     ScoringRequest,
     ScoringResponse,
     ToolCall,
@@ -542,6 +547,7 @@ def v1_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     lora_paths = []
+    return_hidden_states = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -581,6 +587,7 @@ def v1_generate_request(
                 "no_stop_trim": request.no_stop_trim,
                 "ignore_eos": request.ignore_eos,
                 "skip_special_tokens": request.skip_special_tokens,
+                "logit_bias": request.logit_bias,
             }
         )
         return_logprobs.append(request.logprobs is not None)
@@ -588,6 +595,7 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
+        return_hidden_states.append(request.return_hidden_states)
 
     if len(all_requests) == 1:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
@@ -599,6 +607,7 @@ def v1_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         lora_paths = lora_paths[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -615,6 +624,7 @@ def v1_generate_request(
         stream=all_requests[0].stream,
         rid=request_ids,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
@@ -683,6 +693,16 @@ def v1_generate_response(
         else:
             logprobs = None
 
+        hidden_states = None
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        elif (not isinstance(request, list)) and request.return_hidden_states:
+            hidden_states = ret_item["meta_info"].get("hidden_states", None)
+        if hidden_states is not None:
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         if to_file:
@@ -698,6 +718,8 @@ def v1_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
@@ -709,6 +731,7 @@ def v1_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
@@ -777,6 +800,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}
 
             try:
                 async for content in tokenizer_manager.generate_request(
@@ -791,6 +815,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                     completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)
 
                    if not stream_buffer:  # The first chunk
                        if request.echo:
@@ -873,6 +900,27 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     n_prev_tokens[index] = n_prev_token
 
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    text="",
+                                    index=index,
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
                     total_prompt_tokens = sum(
                         tokens
@@ -973,6 +1021,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -1176,6 +1225,7 @@ def v1_chat_generate_request(
             "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
             "skip_special_tokens": request.skip_special_tokens,
+            "logit_bias": request.logit_bias,
         }
 
         if request.response_format and request.response_format.type == "json_schema":
@@ -1215,6 +1265,7 @@ def v1_chat_generate_request(
         image_data_list.append(image_data)
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
+        return_hidden_states.append(request.return_hidden_states)
     if len(all_requests) == 1:
         if is_multimodal:
             # processor will need text input
@@ -1233,6 +1284,7 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
         request_ids = request_ids[0]
+        return_hidden_states = return_hidden_states[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1259,6 +1311,7 @@ def v1_chat_generate_request(
         bootstrap_host=all_requests[0].bootstrap_host,
         bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
+        return_hidden_states=return_hidden_states,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1319,6 +1372,20 @@ def v1_chat_generate_response(
         else:
             choice_logprobs = None
 
+        if isinstance(request, list) and request[idx].return_hidden_states:
+            include_hidden_states = True
+        elif not isinstance(request, list) and request.return_hidden_states:
+            include_hidden_states = True
+        else:
+            include_hidden_states = False
+        if include_hidden_states and ret_item["meta_info"].get("hidden_states", None):
+            hidden_states = ret_item["meta_info"]["hidden_states"]
+            hidden_states = (
+                hidden_states[-1] if hidden_states and len(hidden_states) > 1 else []
+            )
+        else:
+            hidden_states = None
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
         tool_calls = None
@@ -1391,6 +1458,8 @@ def v1_chat_generate_response(
                     else None
                 ),
             }
+            if hidden_states is not None:
+                choice_data["hidden_states"] = hidden_states
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
@@ -1407,6 +1476,7 @@ def v1_chat_generate_response(
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
 
         choices.append(choice_data)
@@ -1479,19 +1549,23 @@ async def v1_chat_completions(
         reasoning_parser_dict = {}
 
         async def generate_stream_resp():
-
+            tool_index_previous = -1
             is_firsts = {}
             stream_buffers = {}
             n_prev_tokens = {}
             prompt_tokens = {}
             completion_tokens = {}
             cached_tokens = {}
+            hidden_states = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
                     text = content["text"]
+                    hidden_states[index] = content["meta_info"].get(
+                        "hidden_states", None
+                    ) or hidden_states.get(index)
 
                     is_first = is_firsts.get(index, True)
                     stream_buffer = stream_buffers.get(index, "")
@@ -1613,6 +1687,7 @@ async def v1_chat_completions(
                    if (delta and len(delta) == 0) or not delta:
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                        continue
 
                    if request.tool_choice != "none" and request.tools:
@@ -1645,6 +1720,7 @@ async def v1_chat_completions(
 
                        # 2) if we found calls, we output them as separate chunk(s)
                        for call_item in calls:
+                            tool_index_current = call_item.tool_index
                            # transform call_item -> FunctionResponse + ToolCall
                            if finish_reason_type == "stop":
                                latest_delta_len = 0
@@ -1671,7 +1747,7 @@ async def v1_chat_completions(
                            tool_call = ToolCall(
                                id=(
                                    f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
-                                    if
+                                    if tool_index_previous != tool_index_current
                                    else None
                                ),
                                index=call_item.tool_index,
@@ -1680,7 +1756,7 @@ async def v1_chat_completions(
                                    arguments=call_item.parameters,
                                ),
                            )
-
+                            tool_index_previous = tool_index_current
                            choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(tool_calls=[tool_call]),
@@ -1701,6 +1777,7 @@ async def v1_chat_completions(
 
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
 
                    else:
                        # No tool calls => just treat this as normal text
@@ -1733,6 +1810,7 @@ async def v1_chat_completions(
                        yield f"data: {chunk.model_dump_json()}\n\n"
                        stream_buffers[index] = new_stream_buffer
                        is_firsts[index] = is_first
+                        n_prev_tokens[index] = n_prev_token
                if finish_reason_type == "stop" and request.tool_choice != "none":
                    parser = FunctionCallParser(
                        tools=request.tools,
@@ -1768,6 +1846,28 @@ async def v1_chat_completions(
 
                 else:
                     usage = None
+                if request.return_hidden_states and hidden_states:
+                    for index, choice_hidden_states in hidden_states.items():
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if choice_hidden_states and len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     created=created,
@@ -1925,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     return response
 
 
+def v1_rerank_request(obj: V1RerankReqInput):
+    if obj.query is None:
+        raise ValueError("query is required")
+    if obj.documents is None or len(obj.documents) == 0:
+        raise ValueError("documents is required")
+
+    pairs = []
+    for doc in obj.documents:
+        pairs.append([obj.query, doc])
+
+    adapted_request = EmbeddingReqInput(
+        text=pairs,
+        is_cross_encoder_request=True,
+    )
+
+    return adapted_request
+
+
+def v1_rerank_response(ret, obj: V1RerankReqInput):
+
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            RerankResponse(
+                score=ret[idx]["embedding"],
+                document=obj.documents[idx],
+                index=idx,
+                meta_info=ret[idx]["meta_info"],
+            )
+        )
+
+    response.sort(key=lambda x: x.score, reverse=True)
+
+    return response
+
+
+async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
+    adapted_request = v1_rerank_request(obj)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_rerank_response(
+        ret,
+        obj,
+    )
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,