sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,568 @@
1
+ # Adapted from
2
+ # https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py
3
+
4
+ from functools import partial
5
+ from typing import Iterable, List, Optional, Tuple, Type, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
10
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
11
+
12
+ from sglang.srt.layers.activation import QuickGELU
13
+ from sglang.srt.layers.attention.vision import VisionAttention
14
+ from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
15
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
16
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
17
+ from sglang.srt.managers.schedule_batch import MultimodalInputs
18
+ from sglang.srt.model_executor.model_runner import ForwardBatch
19
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
20
+ from sglang.srt.utils import add_prefix, flatten_nested_list
21
+
22
+
23
+ class CLIPVisionEmbeddings(nn.Module):
24
+
25
+ def __init__(self, config: CLIPVisionConfig):
26
+ super().__init__()
27
+ self.config = config
28
+ self.embed_dim = config.hidden_size
29
+ self.image_size = config.image_size
30
+ self.patch_size = config.patch_size
31
+ assert self.image_size % self.patch_size == 0
32
+
33
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
34
+
35
+ self.patch_embedding = nn.Conv2d(
36
+ in_channels=config.num_channels,
37
+ out_channels=self.embed_dim,
38
+ kernel_size=self.patch_size,
39
+ stride=self.patch_size,
40
+ bias=False,
41
+ )
42
+
43
+ self.num_patches = (self.image_size // self.patch_size) ** 2
44
+ self.num_positions = self.num_patches + 1
45
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
46
+ self.register_buffer(
47
+ "position_ids",
48
+ torch.arange(self.num_positions).expand((1, -1)),
49
+ persistent=False,
50
+ )
51
+
52
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
53
+ batch_size = pixel_values.shape[0]
54
+ target_dtype = self.patch_embedding.weight.dtype
55
+ patch_embeds = self.patch_embedding(
56
+ pixel_values.to(dtype=target_dtype)
57
+ ) # shape = [*, width, grid, grid]
58
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
59
+
60
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
61
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
62
+ embeddings = embeddings + self.position_embedding(self.position_ids)
63
+
64
+ return embeddings
65
+
66
+
67
+ class CLIPTextEmbeddings(nn.Module):
68
+ def __init__(self, config: CLIPTextConfig):
69
+ super().__init__()
70
+ embed_dim = config.hidden_size
71
+
72
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
73
+ self.position_embedding = nn.Embedding(
74
+ config.max_position_embeddings, embed_dim
75
+ )
76
+
77
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
78
+ self.register_buffer(
79
+ "position_ids",
80
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
81
+ persistent=False,
82
+ )
83
+
84
+ def forward(
85
+ self,
86
+ input_ids: Optional[torch.LongTensor] = None,
87
+ position_ids: Optional[torch.LongTensor] = None,
88
+ inputs_embeds: Optional[torch.FloatTensor] = None,
89
+ ) -> torch.Tensor:
90
+ seq_length = (
91
+ input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
92
+ )
93
+
94
+ if position_ids is None:
95
+ position_ids = self.position_ids[:, :seq_length]
96
+
97
+ if inputs_embeds is None:
98
+ inputs_embeds = self.token_embedding(input_ids)
99
+
100
+ position_embeddings = self.position_embedding(position_ids)
101
+ embeddings = inputs_embeds + position_embeddings
102
+
103
+ return embeddings
104
+
105
+
106
+ class CLIPMLP(nn.Module):
107
+
108
+ def __init__(
109
+ self,
110
+ config,
111
+ act_layer: Type[nn.Module] = QuickGELU,
112
+ quant_config: Optional[QuantizationConfig] = None,
113
+ prefix: str = "",
114
+ ):
115
+ super().__init__()
116
+ self.fc1 = ColumnParallelLinear(
117
+ config.hidden_size,
118
+ config.intermediate_size,
119
+ quant_config=quant_config,
120
+ prefix=add_prefix("fc1", prefix),
121
+ )
122
+ self.act = act_layer()
123
+ self.fc2 = RowParallelLinear(
124
+ config.intermediate_size,
125
+ config.hidden_size,
126
+ quant_config=quant_config,
127
+ prefix=add_prefix("fc2", prefix),
128
+ )
129
+
130
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
+ x_parallel, _ = self.fc1(x)
132
+ x_parallel = self.act(x_parallel)
133
+ x, _ = self.fc2(x_parallel)
134
+ return x
135
+
136
+
137
+ class CLIPEncoderLayer(nn.Module):
138
+
139
+ def __init__(
140
+ self,
141
+ config: CLIPVisionConfig,
142
+ act_layer: Type[nn.Module] = QuickGELU,
143
+ norm_layer: Type[nn.Module] = None,
144
+ attn_implementation: Optional[str] = "sdpa",
145
+ quant_config: Optional[QuantizationConfig] = None,
146
+ prefix: str = "",
147
+ ) -> None:
148
+ super().__init__()
149
+ if norm_layer is None:
150
+ norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
151
+ self.layer_norm1 = norm_layer(config.hidden_size)
152
+ self.layer_norm2 = norm_layer(config.hidden_size)
153
+ if attn_implementation == "sdpa":
154
+ use_context_forward = False
155
+ softmax_in_single_precision = False
156
+ elif attn_implementation == "flash_attention_2":
157
+ softmax_in_single_precision = False
158
+ use_context_forward = True
159
+ elif attn_implementation == "eager":
160
+ softmax_in_single_precision = True
161
+ use_context_forward = False
162
+ self.self_attn = VisionAttention(
163
+ embed_dim=config.hidden_size,
164
+ num_heads=config.num_attention_heads,
165
+ projection_size=config.hidden_size,
166
+ use_qkv_parallel=True,
167
+ use_context_forward=use_context_forward,
168
+ softmax_in_single_precision=softmax_in_single_precision,
169
+ flatten_batch=True,
170
+ quant_config=quant_config,
171
+ prefix=add_prefix("attn", prefix),
172
+ )
173
+ self.mlp = CLIPMLP(
174
+ config,
175
+ act_layer=act_layer,
176
+ quant_config=quant_config,
177
+ prefix=add_prefix("mlp", prefix),
178
+ )
179
+
180
+ def forward(
181
+ self,
182
+ hidden_states: torch.Tensor,
183
+ attention_mask: torch.Tensor,
184
+ causal_attention_mask: torch.Tensor,
185
+ ) -> torch.Tensor:
186
+
187
+ residual = hidden_states
188
+ hidden_states = self.layer_norm1(hidden_states)
189
+ # CLIP text model uses both `causal_attention_mask` and `attention_mask`
190
+ if attention_mask is not None and causal_attention_mask is not None:
191
+ attn_mask = attention_mask + causal_attention_mask
192
+ elif causal_attention_mask is not None:
193
+ attn_mask = causal_attention_mask
194
+ else:
195
+ attn_mask = attention_mask
196
+ hidden_states = self.self_attn(
197
+ hidden_states,
198
+ attention_mask=attn_mask,
199
+ # causal_attention_mask=causal_attention_mask,
200
+ )
201
+
202
+ hidden_states = residual + hidden_states
203
+ residual = hidden_states
204
+ hidden_states = self.layer_norm2(hidden_states)
205
+ hidden_states = self.mlp(hidden_states)
206
+ hidden_states = residual + hidden_states
207
+ return hidden_states
208
+
209
+
210
+ class CLIPEncoder(nn.Module):
211
+ """
212
+ Transformer encoder consisting of `config.num_hidden_layers` self
213
+ attention layers. Each layer is a [`CLIPEncoderLayer`].
214
+
215
+ Args:
216
+ config: CLIPConfig
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ config: CLIPVisionConfig,
222
+ quant_config: Optional[QuantizationConfig] = None,
223
+ prefix: str = "",
224
+ ) -> None:
225
+ super().__init__()
226
+
227
+ self.config = config
228
+
229
+ num_hidden_layers = config.num_hidden_layers
230
+ norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
231
+ self.layers = nn.ModuleList(
232
+ [
233
+ CLIPEncoderLayer(
234
+ config=config,
235
+ norm_layer=norm_layer,
236
+ attn_implementation="sdpa",
237
+ quant_config=quant_config,
238
+ prefix=add_prefix(f"layers.{layer_idx}", prefix),
239
+ )
240
+ for layer_idx in range(num_hidden_layers)
241
+ ]
242
+ )
243
+
244
+ def forward(
245
+ self,
246
+ inputs_embeds: torch.Tensor,
247
+ attention_mask: torch.Tensor = None,
248
+ causal_attention_mask: torch.Tensor = None,
249
+ return_all_hidden_states: bool = False,
250
+ ) -> Union[torch.Tensor, list[torch.Tensor]]:
251
+ hidden_states_pool = [inputs_embeds]
252
+ hidden_states = inputs_embeds
253
+
254
+ for encoder_layer in self.layers:
255
+ hidden_states = encoder_layer(
256
+ hidden_states, attention_mask, causal_attention_mask
257
+ )
258
+ if return_all_hidden_states:
259
+ hidden_states_pool.append(hidden_states)
260
+ if return_all_hidden_states:
261
+ return hidden_states_pool
262
+ return hidden_states
263
+
264
+
265
+ class CLIPTextTransformer(nn.Module):
266
+ def __init__(
267
+ self,
268
+ config: CLIPTextConfig,
269
+ quant_config: Optional[QuantizationConfig] = None,
270
+ prefix: str = "",
271
+ ) -> None:
272
+ super().__init__()
273
+ self.config = config
274
+ embed_dim = config.hidden_size
275
+ self.embeddings = CLIPTextEmbeddings(config)
276
+ self.encoder = CLIPEncoder(
277
+ config=config,
278
+ quant_config=quant_config,
279
+ prefix=add_prefix("encoder", prefix),
280
+ )
281
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
282
+
283
+ @property
284
+ def device(self) -> torch.device:
285
+ return self.encoder.layers[0].layer_norm1.weight.device
286
+
287
+ def forward(
288
+ self,
289
+ input_ids: torch.Tensor,
290
+ attention_mask: Optional[torch.Tensor] = None,
291
+ position_ids: Optional[torch.Tensor] = None,
292
+ ):
293
+ input_shape = input_ids.size()
294
+ input_ids = input_ids.view(-1, input_shape[-1])
295
+ hidden_states = self.embeddings(input_ids, position_ids)
296
+ causal_attention_mask = _create_4d_causal_attention_mask(
297
+ input_ids.shape, hidden_states.dtype, device=hidden_states.device
298
+ )
299
+ encoder_outputs = self.encoder(
300
+ hidden_states, attention_mask, causal_attention_mask
301
+ )
302
+ last_hidden_state = self.final_layer_norm(encoder_outputs)
303
+ return last_hidden_state
304
+
305
+
306
+ class CLIPTextModel(nn.Module):
307
+ def __init__(
308
+ self,
309
+ config: CLIPTextConfig,
310
+ quant_config: Optional[QuantizationConfig] = None,
311
+ prefix: str = "",
312
+ ) -> None:
313
+ super().__init__()
314
+ self.config = config
315
+ self.text_model = CLIPTextTransformer(
316
+ config=config,
317
+ quant_config=quant_config,
318
+ prefix=add_prefix("text_model", prefix),
319
+ )
320
+
321
+ def forward(
322
+ self,
323
+ input_ids: torch.Tensor,
324
+ position_ids: torch.Tensor,
325
+ ):
326
+ return self.text_model(input_ids, position_ids)
327
+
328
+
329
+ class CLIPVisionTransformer(nn.Module):
330
+
331
+ def __init__(
332
+ self,
333
+ config: CLIPVisionConfig,
334
+ quant_config: Optional[QuantizationConfig] = None,
335
+ prefix: str = "",
336
+ ) -> None:
337
+ super().__init__()
338
+
339
+ self.config = config
340
+ embed_dim = config.hidden_size
341
+
342
+ self.embeddings = CLIPVisionEmbeddings(config)
343
+
344
+ # NOTE: This typo of "layrnorm" is not fixed on purpose to match
345
+ # the original transformers code and name of the model weights.
346
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
347
+
348
+ self.encoder = CLIPEncoder(
349
+ config=config,
350
+ quant_config=quant_config,
351
+ prefix=add_prefix("encoder", prefix),
352
+ )
353
+
354
+ num_hidden_layers = config.num_hidden_layers
355
+ if len(self.encoder.layers) > config.num_hidden_layers:
356
+ raise ValueError(
357
+ f"The original encoder only has {num_hidden_layers} "
358
+ f"layers, but you requested {len(self.encoder.layers)} layers."
359
+ )
360
+
361
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
362
+
363
+ @property
364
+ def device(self) -> torch.device:
365
+ return self.encoder.layers[0].layer_norm1.weight.device
366
+
367
+ def forward(
368
+ self,
369
+ pixel_values: torch.Tensor,
370
+ ) -> torch.Tensor:
371
+ hidden_states = self.embeddings(pixel_values.to(self.device))
372
+ hidden_states = self.pre_layrnorm(hidden_states)
373
+
374
+ return_all_hidden_states = False
375
+
376
+ last_hidden_state = self.encoder(
377
+ inputs_embeds=hidden_states,
378
+ return_all_hidden_states=return_all_hidden_states,
379
+ )
380
+
381
+ last_hidden_state = self.post_layernorm(last_hidden_state)
382
+
383
+ return last_hidden_state
384
+
385
+
386
+ class CLIPVisionModel(nn.Module):
387
+ def __init__(
388
+ self,
389
+ config: CLIPVisionConfig,
390
+ quant_config: Optional[QuantizationConfig] = None,
391
+ prefix: str = "",
392
+ ):
393
+ super().__init__()
394
+ self.vision_model = CLIPVisionTransformer(
395
+ config, quant_config, prefix=add_prefix("vision_model", prefix)
396
+ )
397
+
398
+ def forward(self, pixel_values: torch.Tensor):
399
+ return self.vision_model(pixel_values)
400
+
401
+
402
+ class CLIPModel(nn.Module):
403
+ def __init__(
404
+ self,
405
+ config: CLIPConfig,
406
+ quant_config: Optional[QuantizationConfig] = None,
407
+ prefix: str = "",
408
+ ) -> None:
409
+ super().__init__()
410
+ self.config = config
411
+ if not isinstance(config.text_config, CLIPTextConfig):
412
+ raise TypeError(
413
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
414
+ f" {type(config.text_config)}."
415
+ )
416
+
417
+ if not isinstance(config.vision_config, CLIPVisionConfig):
418
+ raise TypeError(
419
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
420
+ f" {type(config.vision_config)}."
421
+ )
422
+
423
+ text_config = config.text_config
424
+ vision_config = config.vision_config
425
+
426
+ self.projection_dim = config.projection_dim
427
+ self.text_embed_dim = text_config.hidden_size
428
+ self.vision_embed_dim = vision_config.hidden_size
429
+ self.visual_projection = nn.Linear(
430
+ self.vision_embed_dim, self.projection_dim, bias=False
431
+ )
432
+ self.text_projection = nn.Linear(
433
+ self.text_embed_dim, self.projection_dim, bias=False
434
+ )
435
+ self.logit_scale = nn.Parameter(
436
+ torch.tensor(self.config.logit_scale_init_value)
437
+ )
438
+
439
+ text_model = CLIPTextModel(
440
+ text_config, quant_config, prefix=add_prefix("text_model", prefix)
441
+ )
442
+ vision_model = CLIPVisionModel(
443
+ vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
444
+ )
445
+ self.text_model = text_model.text_model
446
+ self.vision_model = vision_model.vision_model
447
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
448
+ monkey_patch_weight_loader()
449
+
450
+ def forward(
451
+ self,
452
+ input_ids: torch.Tensor,
453
+ positions: torch.Tensor,
454
+ forward_batch: ForwardBatch,
455
+ get_embedding: bool = True,
456
+ ):
457
+ assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
458
+ mm_inputs = []
459
+ if forward_batch.mm_inputs is not None:
460
+ mm_inputs = forward_batch.mm_inputs
461
+ pixel_values_list = [
462
+ item.pixel_values
463
+ for item in flatten_nested_list(
464
+ [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
465
+ )
466
+ ]
467
+ if len(pixel_values_list) != 0:
468
+ pixel_values = torch.concat(pixel_values_list)
469
+ vision_outputs = self.vision_model(pixel_values)
470
+ pooled_output = vision_outputs[:, 0, :]
471
+ image_embeds = self.visual_projection(pooled_output)
472
+ image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
473
+ return EmbeddingPoolerOutput(embeddings=image_embeds)
474
+
475
+ else:
476
+ text_outputs = self.text_model(input_ids, position_ids=positions)
477
+ pooled_output = self.pooler(text_outputs[0], forward_batch)
478
+ return EmbeddingPoolerOutput(
479
+ embeddings=self.text_projection(pooled_output.embeddings)
480
+ )
481
+
482
+ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
483
+ # Clip embeddings models handle text/image separately, so we don't need to pad input ids
484
+ return input_ids
485
+
486
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
487
+ stacked_params_mapping = [
488
+ # (param_name, shard_name, shard_id)
489
+ ("qkv_proj", "q_proj", "q"),
490
+ ("qkv_proj", "k_proj", "k"),
491
+ ("qkv_proj", "v_proj", "v"),
492
+ ]
493
+ params_dict = dict(self.named_parameters())
494
+ for name, loaded_weight in weights:
495
+ if "position_ids" in name:
496
+ continue
497
+ if "out_proj" in name:
498
+ name = name.replace("out_proj", "proj")
499
+ for param_name, shard_name, shard_id in stacked_params_mapping:
500
+ if shard_name not in name:
501
+ continue
502
+ name = name.replace(shard_name, param_name)
503
+ param = params_dict[name]
504
+ weight_loader = param.weight_loader
505
+ weight_loader(param, loaded_weight, shard_id)
506
+ break
507
+ else:
508
+ param = params_dict[name]
509
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
510
+ weight_loader(param, loaded_weight)
511
+
512
+
513
+ # monkey patch weight loader to remove open_clip file
514
+ def monkey_patch_weight_loader():
515
+ import glob
516
+ import os
517
+
518
+ from sglang.srt.model_loader.loader import DefaultModelLoader
519
+ from sglang.srt.model_loader.weight_utils import (
520
+ download_weights_from_hf,
521
+ filter_files_not_needed_for_inference,
522
+ )
523
+
524
+ def prepare_weights(
525
+ self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
526
+ ) -> Tuple[str, List[str], bool]:
527
+ model_name_or_path = (
528
+ self._maybe_download_from_modelscope(model_name_or_path, revision)
529
+ or model_name_or_path
530
+ )
531
+
532
+ is_local = os.path.isdir(model_name_or_path)
533
+ use_safetensors = False
534
+ allow_patterns = ["*.bin"]
535
+
536
+ if not is_local:
537
+ hf_folder = download_weights_from_hf(
538
+ model_name_or_path,
539
+ self.load_config.download_dir,
540
+ allow_patterns,
541
+ revision,
542
+ ignore_patterns=self.load_config.ignore_patterns,
543
+ )
544
+ else:
545
+ hf_folder = model_name_or_path
546
+
547
+ hf_weights_files: List[str] = []
548
+ for pattern in allow_patterns:
549
+ hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
550
+
551
+ hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
552
+
553
+ # remove open_clip file
554
+ hf_weights_files = [
555
+ file for file in hf_weights_files if "open_clip" not in file
556
+ ]
557
+
558
+ if len(hf_weights_files) == 0:
559
+ raise RuntimeError(
560
+ f"Cannot find any model weights with `{model_name_or_path}`"
561
+ )
562
+
563
+ return hf_folder, hf_weights_files, use_safetensors
564
+
565
+ setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)
566
+
567
+
568
+ EntryClass = CLIPModel
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
51
51
  MultiModalityDataPaddingPatternTokenPairs,
52
52
  general_mm_embed_routine,
53
53
  )
54
- from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
54
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
55
55
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
56
56
  from sglang.srt.model_loader.weight_utils import default_weight_loader
57
57
  from sglang.srt.models.llama import LlamaForCausalLM
@@ -252,7 +252,7 @@ def resample_patch_embed(
252
252
  try:
253
253
  from torch import vmap
254
254
  except ImportError:
255
- from functorch import vmap
255
+ from torch.func import vmap
256
256
 
257
257
  assert len(patch_embed.shape) == 4, "Four dimensions expected"
258
258
  assert len(new_size) == 2, "New shape should only be hw"
@@ -1084,7 +1084,7 @@ def create_siglip_vit(
1084
1084
  )
1085
1085
 
1086
1086
  if ckpt_path:
1087
- state_dict = torch.load(ckpt_path, map_location="cpu")
1087
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
1088
1088
 
1089
1089
  incompatible_keys = model.load_state_dict(state_dict, strict=False)
1090
1090
  print(
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1959
1959
  )
1960
1960
  self.logits_processor = LogitsProcessor(config)
1961
1961
 
1962
- def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
1963
- pixel_values = image_input.pixel_values
1962
+ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
1963
+ pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
1964
1964
  bs, n = pixel_values.shape[0:2]
1965
1965
  pixel_values = pixel_values.to(
1966
1966
  device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1976
1976
  return images_embeds
1977
1977
 
1978
1978
  def get_input_embeddings(self) -> nn.Embedding:
1979
- return self.language_model.model.embed_tokens
1979
+ return self.language_model.get_input_embeddings()
1980
1980
 
1981
1981
  @torch.no_grad()
1982
1982
  def forward(
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1984
1984
  input_ids: torch.LongTensor,
1985
1985
  positions: torch.Tensor,
1986
1986
  forward_batch: ForwardBatch,
1987
+ get_embedding: bool = False,
1987
1988
  ) -> torch.Tensor:
1988
-
1989
- inputs_embeds = general_mm_embed_routine(
1989
+ hidden_states = general_mm_embed_routine(
1990
1990
  input_ids=input_ids,
1991
1991
  forward_batch=forward_batch,
1992
- embed_tokens=self.get_input_embeddings(),
1993
- mm_data_embedding_func=self.get_image_feature,
1994
- )
1995
-
1996
- return self.language_model(
1997
- input_ids=None,
1992
+ image_data_embedding_func=self.get_image_feature,
1993
+ language_model=self.language_model,
1998
1994
  positions=positions,
1999
- forward_batch=forward_batch,
2000
- input_embeds=inputs_embeds,
2001
- get_embedding=False,
2002
1995
  )
2003
1996
 
1997
+ return hidden_states
1998
+
2004
1999
  def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
2005
2000
  return self.gen_aligner(self.gen_embed(image_ids))
2006
2001