sglang-0.4.7-py3-none-any.whl → sglang-0.4.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/models/roberta.py
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
 
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = forward_batch.token_type_ids
         if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
         *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
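
The rerank path added above reduces to CLS pooling followed by a dense → tanh → projection head. A standalone sketch of that computation in plain PyTorch, with toy sizes, not sglang's actual call path:

import torch
import torch.nn as nn

hidden_size, num_labels = 768, 1
dense = nn.Linear(hidden_size, hidden_size)
out_proj = nn.Linear(hidden_size, num_labels)

# One packed sequence of 16 token hidden states; position 0 is <s>/[CLS].
hidden_states = torch.randn(16, hidden_size)
x = hidden_states[0, :]                 # take <s> token, as in the head above
score = out_proj(torch.tanh(dense(x)))  # one relevance score per label
print(score.shape)                      # torch.Size([1])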
sglang/srt/models/vila.py
@@ -0,0 +1,305 @@
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+
+##### BEGIN COPY configuration.py #####
+
+
+class VILAConfig(PretrainedConfig):
+    # Class attributes.
+    model_type: str = "vila"
+    sub_configs: Dict[str, PretrainedConfig] = {
+        "text_config": Qwen2Config(),
+        "vision_config": SiglipVisionConfig(),
+    }
+    _auto_class: Optional[str] = "AutoConfig"
+
+    # Configuration for sub-modules.
+    text_config: Qwen2Config = Qwen2Config()
+    vision_config: SiglipVisionConfig = SiglipVisionConfig()
+
+    # Model configuration.
+    hidden_size: int
+    image_token_id: int
+    mm_hidden_size: int
+    mm_projector_type: str
+    mm_vision_select_feature: str
+    mm_vision_select_layer: int
+    video_token_id: int
+
+    def __init__(
+        self,
+        text_config: Optional[Dict[str, Any]] = None,
+        vision_config: Optional[Dict[str, Any]] = None,
+        *,
+        hidden_size: int = 1536,
+        image_token_id: int = 151649,
+        mm_hidden_size: int = 1152,
+        mm_projector_type: str = "mlp_downsample_3x3_fix",
+        mm_vision_select_feature: str = "cls_patch",
+        mm_vision_select_layer: int = -2,
+        video_token_id: int = 151650,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config
+            else SiglipVisionConfig()
+        )
+
+        self.hidden_size = hidden_size
+        self.image_token_id = image_token_id
+        self.mm_hidden_size = mm_hidden_size
+        self.mm_projector_type = mm_projector_type
+        self.mm_vision_select_feature = mm_vision_select_feature
+        self.mm_vision_select_layer = mm_vision_select_layer
+        self.video_token_id = video_token_id
+
+
+##### END COPY configuration.py #####
+
+##### BEGIN COPY modeling_vila.py #####
+
+
+class DownSample3x3BlockFix(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
+        """
+
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = int(sequence_length**0.5)
+        if feat_size**2 != sequence_length:
+            raise ValueError(
+                f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
+            )
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class MultimodalProjector(nn.Module):
+    layers: nn.Sequential
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        if config.mm_projector_type == "mlp_downsample_3x3_fix":
+            self.layers = nn.Sequential(
+                DownSample3x3BlockFix(),
+                nn.LayerNorm(config.mm_hidden_size * 9),
+                nn.Linear(
+                    config.mm_hidden_size * 9,
+                    config.mm_hidden_size * 3,
+                ),
+                nn.GELU(),
+                nn.LayerNorm(config.vision_config.hidden_size * 3),
+                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
+                nn.GELU(),
+                nn.Linear(config.hidden_size, config.hidden_size),
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_projector_type: {config.mm_projector_type}"
+            )
+
+        self.layers.type(config.torch_dtype)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
+
+        Returns:
+            The output tensor of shape (batch_size, image_pad_len, hidden_size).
+        """
+
+        return self.layers(x.to(device=self.device, dtype=self.dtype))
+
+
+##### END COPY modeling_vila.py #####
+
+
+class VILAForConditionalGeneration(nn.Module):
+    config: VILAConfig
+    quant_config: Optional[QuantizationConfig]
+
+    logits_processor: LogitsProcessor
+    pooler: Pooler
+
+    llm: Qwen2ForCausalLM
+    mm_projector: MultimodalProjector
+    vision_tower: SiglipVisionModel
+
+    def __init__(
+        self,
+        config: VILAConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+        self.mm_projector = MultimodalProjector(config)
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.config.torch_dtype
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        return cast(LogitsProcessorOutput, output)
+
+    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
+        pixel_values = cast(Tensor, mm_input[0].pixel_values)
+
+        ##### BEGIN COPY modeling_vila.py #####
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+
+        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
+            vision_tower_output
+        )
+
+        image_embedding: Tensor = self.mm_projector.__call__(
+            mm_projector_input.to(
+                device=self.mm_projector.device, dtype=self.mm_projector.dtype
+            )
+        )
+
+        ##### END COPY modeling_vila.py #####
+
+        return image_embedding
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        image_inputs: MultimodalInputs,
+    ) -> List[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
+            token_ids=[self.config.image_token_id],
+        )
+
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
+    ##### BEGIN COPY modeling_vila.py #####
+
+    def _vision_tower_output_to_mm_projector_input(
+        self,
+        vision_tower_output: BaseModelOutputWithPooling,
+    ) -> Tensor:
+        assert vision_tower_output.hidden_states is not None
+
+        selected_layer_hidden_states = vision_tower_output.hidden_states[
+            self.config.mm_vision_select_layer
+        ]
+
+        if self.config.mm_vision_select_feature == "cls_patch":
+            return selected_layer_hidden_states
+        else:
+            raise NotImplementedError(
+                f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
+            )
+
+    ##### END COPY modeling_vila.py #####
+
+
+EntryClass = [VILAForConditionalGeneration]
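
DownSample3x3BlockFix is the only nontrivial reshape in the projector: it groups a square patch grid into 3x3 blocks, trading 9x fewer tokens for 9x wider features. A minimal shape check of the same reshape sequence, assuming an illustrative 27x27 patch grid (the real grid size depends on the SigLIP config):

import torch

batch, hidden = 2, 1152
x = torch.randn(batch, 729, hidden)       # 27 * 27 = 729 patch tokens

feat = x.reshape(batch, 27, 27, hidden)   # 27 % 3 == 0, so no padding needed
feat = feat.reshape(batch, 9, 3, 9, 3, hidden)
feat = feat.permute(0, 1, 3, 2, 4, 5).contiguous()
feat = feat.reshape(batch, -1, 9 * hidden)
print(feat.shape)                         # torch.Size([2, 81, 10368])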
sglang/srt/reasoning_parser.py
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type
 
 
 class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
-        if self.think_end_token not in text:
+        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+        if not in_reasoning:
+            return StreamingParseResult(normal_text=text)
+
+        # The text is considered to be in a reasoning block.
+        processed_text = text.replace(self.think_start_token, "").strip()
+
+        if self.think_end_token not in processed_text:
             # Assume reasoning was truncated before `</think>` token
-            return StreamingParseResult(reasoning_text=text)
+            return StreamingParseResult(reasoning_text=processed_text)
 
         # Extract reasoning content
-        splits = text.split(self.think_end_token, maxsplit=1)
+        splits = processed_text.split(self.think_end_token, maxsplit=1)
         reasoning_text = splits[0]
-        text = splits[1].strip()
+        normal_text = splits[1].strip()
 
-        return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
+        return StreamingParseResult(
+            normal_text=normal_text, reasoning_text=reasoning_text
+        )
 
     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
         """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
         if not self.stripped_think_start and self.think_start_token in current_text:
             current_text = current_text.replace(self.think_start_token, "")
             self.stripped_think_start = True
+            self._in_reasoning = True
 
         # Handle end of reasoning block
         if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     """
 
     def __init__(self, stream_reasoning: bool = True):
-        # Qwen3 is assumed to be reasoning until `</think>` token
+        # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
         super().__init__(
            "<think>",
            "</think>",
-            force_reasoning=True,
+            force_reasoning=False,
            stream_reasoning=stream_reasoning,
        )
 
@@ -151,12 +161,12 @@ class ReasoningParser:
         If True, streams reasoning content as it arrives.
     """
 
-    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
     }
 
-    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
         if not model_type:
             raise ValueError("Model type must be specified")
sglang/srt/sampling/sampling_batch_info.py
@@ -10,6 +10,7 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
+from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"
 
+    # Handle logit bias
+    logit_bias: Optional[torch.Tensor] = None
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)
 
+        logit_bias = None
+        if any(r.sampling_params.logit_bias is not None for r in reqs):
+            logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
+            for i, r in enumerate(reqs):
+                if r.sampling_params.logit_bias is not None:
+                    for key, value in r.sampling_params.logit_bias.items():
+                        logit_bias[i, int(key)] = value
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
             device=device,
+            logit_bias=logit_bias,
         )
         return ret
 
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
         if self.vocab_mask is not None:
             self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
 
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
     def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
         self.penalizer_orchestrator.filter(keep_indices_device)
 
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
             value = getattr(self, item, None)
             setattr(self, item, value[keep_indices_device])
 
+        if self.logit_bias is not None:
+            self.logit_bias = self.logit_bias[keep_indices_device]
+
     def _filter_batch_custom_logit_processor(
         self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
         self.need_top_p_sampling |= other.need_top_p_sampling
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
+
+        # Merge logit bias
+        self.logit_bias = merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
+        )
sglang/srt/sampling/sampling_params.py
@@ -52,6 +52,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -78,6 +79,7 @@ class SamplingParams:
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
         self.stream_interval = stream_interval
+        self.logit_bias = logit_bias
 
         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
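
Taken together with the SamplingBatchInfo change, this lets each request bias individual token ids. A hedged usage sketch; only max_new_tokens and logit_bias are visible in this diff, so treat any other constructor argument as unchanged from 0.4.7:

from sglang.srt.sampling.sampling_params import SamplingParams

# Token ids are dict keys as strings, matching the OpenAI-style
# logit_bias field; large negative values effectively ban a token.
params = SamplingParams(
    max_new_tokens=32,
    logit_bias={"151645": -100.0},  # example token id, chosen arbitrarily
)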