sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x


-# todo
 class DeepseekVL2ForCausalLM(nn.Module):

     def __init__(
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )

-        return
+        return hs

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
                 weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-
-
-
-
-
-
-
-
-
-
-
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
         images_in_this_batch = []
-
-
-
-
-
-
-
-            # [hw, D]
-            global_features = images_embeds[tile_index]
-
-            # [num_height_tiles * num_width_tiles, hw, D]
-            local_features = images_embeds[
-                tile_index + 1 : tile_index + 1 + num_tiles_in_image
-            ]
-            tile_index += num_tiles_in_image + 1
-
-            # format global and local features
-            # ----------------- global view add newline -----------------
-            # [hw, D] -> [h, w, D]
-            global_features = global_features.view(h, w, n_dim)
-
-            # [D] -> [h, 1, D]
-            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
-
-            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
-
-            # [h, w + 1, D] -> [h * (w + 1), D]
-            global_features = global_features.view(-1, n_dim)
-
-            # ----------------- local view add newline -----------------
-            # [num_height_tiles * num_width_tiles, h * w, D] ->
-            # [num_height_tiles * h, num_width_tiles * w, D]
-            local_features = rearrange(
-                local_features,
-                "(th tw) (h w) d -> (th h) (tw w) d",
-                th=num_height_tiles,
-                tw=num_width_tiles,
-                h=h,
-                w=w,
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
             )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw**0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[
+                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
+                ]
+                tile_index += num_tiles_in_image + 1

-
-
-
-
-
-                h
-
+                # format global and local features
+                # ----------------- global view add newline -----------------
+                # [hw, D] -> [h, w, D]
+                global_features = global_features.view(h, w, n_dim)
+
+                # [D] -> [h, 1, D]
+                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

-
-
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
-            local_features = local_features.view(-1, n_dim)
-
-            # merge global and local tiles
-            if self.global_view_pos == "head":
-                global_local_features = torch.cat(
-                    [
-                        global_features,
-                        self.view_seperator[None, :],
-                        local_features,
-                    ]
+                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                global_features = torch.cat(
+                    [global_features, new_lines_in_global], dim=1
                 )
-
-
-
-
-
-
-
+
+                # [h, w + 1, D] -> [h * (w + 1), D]
+                global_features = global_features.view(-1, n_dim)
+
+                # ----------------- local view add newline -----------------
+                # [num_height_tiles * num_width_tiles, h * w, D] ->
+                # [num_height_tiles * h, num_width_tiles * w, D]
+                local_features = rearrange(
+                    local_features,
+                    "(th tw) (h w) d -> (th h) (tw w) d",
+                    th=num_height_tiles,
+                    tw=num_width_tiles,
+                    h=h,
+                    w=w,
                 )

-
+                # [D] -> [num_height_tiles * h, 1, D]
+                new_lines_in_local = repeat(
+                    self.image_newline,
+                    "d -> (th h) 1 d",
+                    th=num_height_tiles,
+                    h=h,
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                local_features = torch.cat([local_features, new_lines_in_local], dim=1)
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                local_features = local_features.view(-1, n_dim)
+
+                # merge global and local tiles
+                if self.global_view_pos == "head":
+                    global_local_features = torch.cat(
+                        [
+                            global_features,
+                            self.view_seperator[None, :],
+                            local_features,
+                        ]
+                    )
+                else:
+                    global_local_features = torch.cat(
+                        [
+                            local_features,
+                            self.view_seperator[None, :],
+                            global_features,
+                        ]
+                    )
+
+                images_in_this_batch.append(global_local_features)

         return torch.cat(images_in_this_batch, dim=0)

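Note on the reworked `get_image_feature` above: for each `MultimodalDataItem`, the projector output is split into one global view plus `num_width_tiles * num_height_tiles` local tiles, a newline embedding is appended to every row of each view, and the two views are joined around `view_seperator`. A minimal sketch of the resulting per-image row count (standalone helper written for illustration only; it is not part of the package):

# Sketch: rows produced per image by the tile layout in
# DeepseekVL2ForCausalLM.get_image_feature, assuming h == w == int(sqrt(hw)).
def expected_image_tokens(hw: int, num_width_tiles: int, num_height_tiles: int) -> int:
    h = w = int(hw**0.5)
    global_tokens = h * (w + 1)  # global view plus one newline per row
    local_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)  # tiled view plus newlines
    return global_tokens + 1 + local_tokens  # +1 for the view_seperator row

# Example: a 24x24 patch grid (hw = 576) cropped into 2x2 tiles
print(expected_image_tokens(576, 2, 2))  # 600 + 1 + 2352 = 2953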
sglang/srt/models/gemma3_causal.py
CHANGED
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import add_prefix, make_layers


+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ class Gemma3Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype

@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
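Note on `get_attention_sliding_window_size` above: the `config.sliding_window - 1` conversion encodes the comment in the diff that HF's Gemma 3 counts the window inclusive of the current token while SGLang's backends assume an exclusive count. The sketch below illustrates one plausible reading of that off-by-one (assumption for illustration; the exact backend mask semantics are not shown in this diff):

# Assumption: HF's window counts the current token, while an "exclusive" window
# counts only the lookback, so the same key positions are visible when sgl = hf - 1.
def visible_keys_hf(q: int, window: int) -> set:
    return set(range(max(0, q - window + 1), q + 1))

def visible_keys_exclusive(q: int, window: int) -> set:
    return set(range(max(0, q - window), q + 1))

hf_window = 512
assert visible_keys_hf(2048, hf_window) == visible_keys_exclusive(2048, hf_window - 1)
assert visible_keys_hf(100, hf_window) == visible_keys_exclusive(100, hf_window - 1)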
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict

 import torch
 from torch import nn
-from transformers import
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -268,17 +265,22 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()

-    def
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.

-        Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values =
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

@@ -286,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features

-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-
-            return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,
@@ -399,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids

-
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            **kwargs,
         )

-        return
+        return hs

     def tie_weights(self):
         return self.language_model.tie_weights()
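Note on the deleted `embed_mm_inputs` above: the per-model logic it contained (mask the multimodal placeholder tokens, embed the text tokens, `masked_scatter` the projected image features into the masked slots) is what `general_mm_embed_routine` now centralizes. A rough sketch of that core step, distilled from the removed code (free function and argument names are hypothetical; the real routine lives in `sglang.srt.managers.mm_utils` and also handles batching and multiple modalities):

import torch

def scatter_image_embeddings(
    input_ids: torch.Tensor,         # [seq_len], image placeholder ids already padded in
    pad_values: list,                # token ids that mark image positions
    text_embed: torch.nn.Embedding,  # language model input embeddings
    image_features: torch.Tensor,    # [num_image_tokens, hidden_size]
) -> torch.Tensor:
    mask = torch.isin(input_ids, torch.tensor(pad_values, device=input_ids.device))
    # Clamp so out-of-vocab placeholder ids still embed without an index error.
    inputs_embeds = text_embed(input_ids.clamp(min=0, max=text_embed.num_embeddings - 1))
    return inputs_embeds.masked_scatter(
        mask.unsqueeze(-1), image_features.to(inputs_embeds.dtype)
    )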
sglang/srt/models/llama.py
CHANGED
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, List, Optional,
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
sglang/srt/models/llava.py
CHANGED
@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+
+        pad_values = [item.pad_value for item in image_inputs.mm_items]

         # hardcode for spatial_unpad + anyres
-        if
-
-
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  #
+            new_image_feature_len = self.image_feature_len  # multi-image

         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + [pad_values[image_idx]] * new_image_feature_len
+                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
         for im in image_inputs:
-            if im
-                modalities_list.extend(im.
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):

         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values =
-
-
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_sizes = [
-
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]

             ########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
             new_image_features = []
             height = width = self.num_patches_per_side
             for image_idx, image_feature in enumerate(image_features):
-                if modalities_list[image_idx] ==
+                if modalities_list[image_idx] == Modality.IMAGE:
                     image_aspect_ratio = (
                         self.config.image_aspect_ratio
                     )  # single image
                 elif (
-                    modalities_list[image_idx] ==
-                    or modalities_list[image_idx] ==
+                    modalities_list[image_idx] == Modality.MULTI_IMAGES
+                    or modalities_list[image_idx] == Modality.VIDEO
                 ):
                     image_aspect_ratio = "pad"  # multi image
                     # image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 if (
                     image_feature.shape[0] > 1
                     and "anyres" in image_aspect_ratio
-                    and modalities_list[image_idx] ==
+                    and modalities_list[image_idx] == Modality.IMAGE
                 ):
                     base_image_feature = image_feature[0]
                     image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     )
                     image_feature = image_feature.unsqueeze(0)
                 else:
-                    if modalities_list[image_idx] ==
+                    if modalities_list[image_idx] == Modality.VIDEO:  # video
                         # 2x2 pooling
                         num_of_frames = image_feature.shape[0]
                         image_feature = image_feature.view(
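Note on `flatten_nested_list`: the llava.py changes above, and the llavavid.py changes below, gather per-item attributes (pixel_values, image_sizes, image_offsets) from `mm_items` with this helper (imported from `sglang.srt.utils` in llava.py). A minimal sketch of the assumed behavior (illustrative reimplementation, not the packaged code):

# Assumed behavior: flatten nested Python lists into one flat list, leaving
# non-list leaves (e.g. tensors) untouched.
def flatten_nested_list(nested):
    if isinstance(nested, list):
        return [leaf for sub in nested for leaf in flatten_nested_list(sub)]
    return [nested]

# e.g. collecting pixel_values across requests, as in LlavaBaseForCausalLM.forward
per_request = [["img0_tile0", "img0_tile1"], ["img1_tile0"]]
print(flatten_nested_list(per_request))  # ['img0_tile0', 'img0_tile1', 'img1_tile0']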
sglang/srt/models/llavavid.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
         new_image_feature_len = self.image_feature_len

         pad_ids = pad_values * (
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)

         if need_vision.any():
-            pixel_values =
-
-
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_offsets = [
-
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]

             ########## Encode Image ########
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())