sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/models/roberta.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
 
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = forward_batch.token_type_ids
         if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
         *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
         params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
sglang/srt/models/vila.py
ADDED
@@ -0,0 +1,305 @@
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+
+##### BEGIN COPY configuration.py #####
+
+
+class VILAConfig(PretrainedConfig):
+    # Class attributes.
+    model_type: str = "vila"
+    sub_configs: Dict[str, PretrainedConfig] = {
+        "text_config": Qwen2Config(),
+        "vision_config": SiglipVisionConfig(),
+    }
+    _auto_class: Optional[str] = "AutoConfig"
+
+    # Configuration for sub-modules.
+    text_config: Qwen2Config = Qwen2Config()
+    vision_config: SiglipVisionConfig = SiglipVisionConfig()
+
+    # Model configuration.
+    hidden_size: int
+    image_token_id: int
+    mm_hidden_size: int
+    mm_projector_type: str
+    mm_vision_select_feature: str
+    mm_vision_select_layer: int
+    video_token_id: int
+
+    def __init__(
+        self,
+        text_config: Optional[Dict[str, Any]] = None,
+        vision_config: Optional[Dict[str, Any]] = None,
+        *,
+        hidden_size: int = 1536,
+        image_token_id: int = 151649,
+        mm_hidden_size: int = 1152,
+        mm_projector_type: str = "mlp_downsample_3x3_fix",
+        mm_vision_select_feature: str = "cls_patch",
+        mm_vision_select_layer: int = -2,
+        video_token_id: int = 151650,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config
+            else SiglipVisionConfig()
+        )
+
+        self.hidden_size = hidden_size
+        self.image_token_id = image_token_id
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.video_token_id = video_token_id
+
+
+##### END COPY configuration.py #####
+
+##### BEGIN COPY modeling_vila.py #####
+
+
+class DownSample3x3BlockFix(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+        """
+
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = int(sequence_length**0.5)
+        if feat_size**2 != sequence_length:
+            raise ValueError(
+                f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
+            )
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class MultimodalProjector(nn.Module):
+    layers: nn.Sequential
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_projector_type: {config.mm_projector_type}"
+            )
+
+        self.layers.type(config.torch_dtype)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, hidden_size).
+        """
+
+        return self.layers(x.to(device=self.device, dtype=self.dtype))
+
+
+##### END COPY modeling_vila.py #####
+
+
+class VILAForConditionalGeneration(nn.Module):
+    config: VILAConfig
+    quant_config: Optional[QuantizationConfig]
+
+    logits_processor: LogitsProcessor
+    pooler: Pooler
+
+    llm: Qwen2ForCausalLM
+    mm_projector: MultimodalProjector
+    vision_tower: SiglipVisionModel
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+        self.mm_projector = MultimodalProjector(config)
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.config.torch_dtype
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        return cast(LogitsProcessorOutput, output)
+
+    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
+        pixel_values = cast(Tensor, mm_input[0].pixel_values)
+
+        ##### BEGIN COPY modeling_vila.py #####
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+
+        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
+            vision_tower_output
+        )
+
+        image_embedding: Tensor = self.mm_projector.__call__(
+            mm_projector_input.to(
+                device=self.mm_projector.device, dtype=self.mm_projector.dtype
+            )
+        )
+
+        ##### END COPY modeling_vila.py #####
+
+        return image_embedding
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        image_inputs: MultimodalInputs,
+    ) -> List[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
+            token_ids=[self.config.image_token_id],
+        )
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+    ##### BEGIN COPY modeling_vila.py #####
+
+    def _vision_tower_output_to_mm_projector_input(
+        self,
+        vision_tower_output: BaseModelOutputWithPooling,
+    ) -> Tensor:
+        assert vision_tower_output.hidden_states is not None
+
+        selected_layer_hidden_states = vision_tower_output.hidden_states[
+            self.config.mm_vision_select_layer
+        ]
+
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
+            )
+
+    ##### END COPY modeling_vila.py #####
+
+
+EntryClass = [VILAForConditionalGeneration]
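
Note: the only projector implemented here, mlp_downsample_3x3_fix, hinges on DownSample3x3BlockFix folding every 3x3 neighborhood of the (square) vision-token grid into one token, so the token count shrinks roughly 9x while the channel width grows 9x before the MLP. A standalone shape check using the same reshape/permute recipe; the 27x27 grid and hidden size 1152 are illustrative, not values the model guarantees:

import torch
import torch.nn.functional as F

x = torch.randn(1, 27 * 27, 1152)   # (batch, 729 tokens, mm_hidden_size)
b, s, h = x.shape
side = int(s**0.5)                  # 27; sequence_length must be a perfect square

feats = x.reshape(b, side, side, h)
pad = (3 - side % 3) % 3            # 0 for 27; otherwise pad right/bottom edges
if pad:
    feats = F.pad(feats, (0, 0, 0, pad, 0, pad))
    side += pad

# Group each 3x3 neighborhood, then flatten it into the channel dimension.
feats = feats.reshape(b, side // 3, 3, side // 3, 3, h)
feats = feats.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(b, -1, 9 * h)
print(feats.shape)                  # torch.Size([1, 81, 10368]): 9x fewer tokens, 9x wider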
sglang/srt/reasoning_parser.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type
 
 
 class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
-        if self.think_end_token not in text:
+        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+        if not in_reasoning:
+            return StreamingParseResult(normal_text=text)
+
+        # The text is considered to be in a reasoning block.
+        processed_text = text.replace(self.think_start_token, "").strip()
+
+        if self.think_end_token not in processed_text:
             # Assume reasoning was truncated before `</think>` token
-            return StreamingParseResult(reasoning_text=text)
+            return StreamingParseResult(reasoning_text=processed_text)
 
         # Extract reasoning content
-        splits = text.split(self.think_end_token, maxsplit=1)
+        splits = processed_text.split(self.think_end_token, maxsplit=1)
         reasoning_text = splits[0]
-        text = splits[1].strip()
+        normal_text = splits[1].strip()
 
-        return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
+        return StreamingParseResult(
+            normal_text=normal_text, reasoning_text=reasoning_text
+        )
 
     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
         """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
         if not self.stripped_think_start and self.think_start_token in current_text:
             current_text = current_text.replace(self.think_start_token, "")
             self.stripped_think_start = True
+            self._in_reasoning = True
 
         # Handle end of reasoning block
         if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     """
 
     def __init__(self, stream_reasoning: bool = True):
-        # Qwen3 is assumed to be reasoning until `</think>` token
+        # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
        super().__init__(
             "<think>",
             "</think>",
-            force_reasoning=True,
+            force_reasoning=False,
             stream_reasoning=stream_reasoning,
         )
 
@@ -151,12 +161,12 @@ class ReasoningParser:
             If True, streams reasoning content as it arrives.
     """
 
-    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
     }
 
-    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
         if not model_type:
             raise ValueError("Model type must be specified")
 
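
Note: with force_reasoning=False, Qwen3Detector now treats text as reasoning only when it actually opens with <think> (or a previous streamed chunk opened a block), which is what lets enable_thinking=False outputs pass through untouched. A usage sketch of the one-shot path, assuming the parser's usual import location sglang.srt.reasoning_parser:

from sglang.srt.reasoning_parser import Qwen3Detector

det = Qwen3Detector()

# Thinking enabled: the text opens a <think> block, so it is split.
res = det.detect_and_parse("<think>step by step</think>The answer is 42.")
print(res.reasoning_text)  # "step by step"
print(res.normal_text)     # "The answer is 42."

# enable_thinking=False style output: no <think> prefix, passed through as-is.
res = det.detect_and_parse("The answer is 42.")
print(res.normal_text)     # "The answer is 42."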
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -10,6 +10,7 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
+from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"
 
+    # Handle logit bias
+    logit_bias: Optional[torch.Tensor] = None
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
 
+        logit_bias = None
+        if any(r.sampling_params.logit_bias is not None for r in reqs):
+            logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
+            for i, r in enumerate(reqs):
+                if r.sampling_params.logit_bias is not None:
+                    for key, value in r.sampling_params.logit_bias.items():
+                        logit_bias[i, int(key)] = value
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
             device=device,
+            logit_bias=logit_bias,
         )
         return ret
 
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
         if self.vocab_mask is not None:
             self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
 
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
     def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
         self.penalizer_orchestrator.filter(keep_indices_device)
 
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
             value = getattr(self, item, None)
             setattr(self, item, value[keep_indices_device])
 
+        if self.logit_bias is not None:
+            self.logit_bias = self.logit_bias[keep_indices_device]
+
     def _filter_batch_custom_logit_processor(
         self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
         self.need_top_p_sampling |= other.need_top_p_sampling
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
+
+        # Merge logit bias
+        self.logit_bias = merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
+        )
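
Note: per-request logit_bias dicts are materialized once per batch as a dense (batch_size, vocab_size) tensor and added to the logits in place right after the grammar vocab mask; rows for requests without a bias stay zero, so the add is a no-op for them. A standalone sketch of that scatter-and-add, with illustrative sizes and token ids:

import torch

vocab_size = 8
requests = [{"3": 5.0, "5": -100.0}, None]  # per request: token id (as str) -> bias

logit_bias = torch.zeros(len(requests), vocab_size)
for i, bias in enumerate(requests):
    if bias is not None:
        for key, value in bias.items():
            logit_bias[i, int(key)] = value

logits = torch.zeros(len(requests), vocab_size)
logits.add_(logit_bias)  # same in-place add as in the hunk above
print(logits[0])  # tensor([0., 0., 0., 5., 0., -100., 0., 0.])
print(logits[1])  # all zeros: a request without logit_bias is untouched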
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -52,6 +52,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -78,6 +79,7 @@ class SamplingParams:
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
         self.stream_interval = stream_interval
+        self.logit_bias = logit_bias
 
         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS: