sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
|
|
11
11
|
)
|
12
12
|
from sglang.srt.layers.linear import ReplicatedLinear
|
13
13
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
14
|
-
from sglang.srt.managers.
|
14
|
+
from sglang.srt.managers.mm_utils import (
|
15
|
+
MultiModalityDataPaddingPatternImageTokens,
|
16
|
+
general_mm_embed_routine,
|
17
|
+
)
|
18
|
+
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
15
19
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
16
20
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
17
21
|
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
|
|
150
154
|
return x
|
151
155
|
|
152
156
|
|
153
|
-
# todo
|
154
157
|
class DeepseekVL2ForCausalLM(nn.Module):
|
155
158
|
|
156
159
|
def __init__(
|
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|
215
218
|
forward_batch: ForwardBatch,
|
216
219
|
**kwargs: object,
|
217
220
|
):
|
218
|
-
|
219
|
-
if (
|
220
|
-
forward_batch.forward_mode.is_extend()
|
221
|
-
and forward_batch.contains_image_inputs()
|
222
|
-
):
|
223
|
-
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
224
|
-
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
225
|
-
for idx, image in enumerate(forward_batch.mm_inputs):
|
226
|
-
if image is None:
|
227
|
-
continue
|
228
|
-
start_idx = extend_start_loc_cpu[idx]
|
229
|
-
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
230
|
-
images_emb_mask = image.images_emb_mask.to(device="cuda")
|
231
|
-
image_features = self.get_image_feature(image)
|
232
|
-
input_embeds[start_idx:end_idx] = input_embeds[
|
233
|
-
start_idx:end_idx
|
234
|
-
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
|
235
|
-
|
236
|
-
outputs = self.language_model.forward(
|
221
|
+
hs = general_mm_embed_routine(
|
237
222
|
input_ids=input_ids,
|
238
223
|
positions=positions,
|
239
224
|
forward_batch=forward_batch,
|
240
|
-
|
225
|
+
image_data_embedding_func=self.get_image_feature,
|
226
|
+
language_model=self.language_model,
|
241
227
|
)
|
242
228
|
|
243
|
-
return
|
229
|
+
return hs
|
244
230
|
|
245
231
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
246
232
|
stacked_params_mapping = [
|
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|
263
249
|
weights_loader(param, loaded_weight)
|
264
250
|
|
265
251
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
252
|
+
helper = MultiModalityDataPaddingPatternImageTokens(
|
253
|
+
image_token_id=image_inputs.im_token_id
|
254
|
+
)
|
255
|
+
return helper.pad_input_tokens(input_ids, image_inputs)
|
256
|
+
|
257
|
+
def get_image_feature(self, items: List[MultimodalDataItem]):
|
258
|
+
|
259
|
+
images_spatial_crop = torch.cat(
|
260
|
+
[item.image_spatial_crop for item in items], dim=0
|
261
|
+
)
|
262
|
+
|
263
|
+
assert images_spatial_crop.dim() == 3
|
264
|
+
|
265
|
+
# TODO: can it be batched ?
|
277
266
|
images_in_this_batch = []
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
# [hw, D]
|
286
|
-
global_features = images_embeds[tile_index]
|
287
|
-
|
288
|
-
# [num_height_tiles * num_width_tiles, hw, D]
|
289
|
-
local_features = images_embeds[
|
290
|
-
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
291
|
-
]
|
292
|
-
tile_index += num_tiles_in_image + 1
|
293
|
-
|
294
|
-
# format global and local features
|
295
|
-
# ----------------- global view add newline -----------------
|
296
|
-
# [hw, D] -> [h, w, D]
|
297
|
-
global_features = global_features.view(h, w, n_dim)
|
298
|
-
|
299
|
-
# [D] -> [h, 1, D]
|
300
|
-
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
301
|
-
|
302
|
-
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
303
|
-
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
|
304
|
-
|
305
|
-
# [h, w + 1, D] -> [h * (w + 1), D]
|
306
|
-
global_features = global_features.view(-1, n_dim)
|
307
|
-
|
308
|
-
# ----------------- local view add newline -----------------
|
309
|
-
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
310
|
-
# [num_height_tiles * h, num_width_tiles * w, D]
|
311
|
-
local_features = rearrange(
|
312
|
-
local_features,
|
313
|
-
"(th tw) (h w) d -> (th h) (tw w) d",
|
314
|
-
th=num_height_tiles,
|
315
|
-
tw=num_width_tiles,
|
316
|
-
h=h,
|
317
|
-
w=w,
|
267
|
+
for item in items:
|
268
|
+
assert item.pixel_values.dim() == 4
|
269
|
+
image_feature = self.vision.forward_features(
|
270
|
+
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
|
271
|
+
device=next(self.vision.parameters()).device
|
272
|
+
)
|
318
273
|
)
|
274
|
+
images_embeds = self.projector(image_feature)
|
275
|
+
_, hw, n_dim = images_embeds.shape
|
276
|
+
h = w = int(hw**0.5)
|
277
|
+
tile_index = 0
|
278
|
+
for jdx in range(item.image_spatial_crop.shape[1]):
|
279
|
+
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
|
280
|
+
if num_width_tiles == 0 or num_height_tiles == 0:
|
281
|
+
break
|
282
|
+
num_tiles_in_image = num_width_tiles * num_height_tiles
|
283
|
+
|
284
|
+
# [hw, D]
|
285
|
+
global_features = images_embeds[tile_index]
|
286
|
+
|
287
|
+
# [num_height_tiles * num_width_tiles, hw, D]
|
288
|
+
local_features = images_embeds[
|
289
|
+
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
290
|
+
]
|
291
|
+
tile_index += num_tiles_in_image + 1
|
319
292
|
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
h
|
326
|
-
|
293
|
+
# format global and local features
|
294
|
+
# ----------------- global view add newline -----------------
|
295
|
+
# [hw, D] -> [h, w, D]
|
296
|
+
global_features = global_features.view(h, w, n_dim)
|
297
|
+
|
298
|
+
# [D] -> [h, 1, D]
|
299
|
+
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
327
300
|
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
332
|
-
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
333
|
-
local_features = local_features.view(-1, n_dim)
|
334
|
-
|
335
|
-
# merge global and local tiles
|
336
|
-
if self.global_view_pos == "head":
|
337
|
-
global_local_features = torch.cat(
|
338
|
-
[
|
339
|
-
global_features,
|
340
|
-
self.view_seperator[None, :],
|
341
|
-
local_features,
|
342
|
-
]
|
301
|
+
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
302
|
+
global_features = torch.cat(
|
303
|
+
[global_features, new_lines_in_global], dim=1
|
343
304
|
)
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
305
|
+
|
306
|
+
# [h, w + 1, D] -> [h * (w + 1), D]
|
307
|
+
global_features = global_features.view(-1, n_dim)
|
308
|
+
|
309
|
+
# ----------------- local view add newline -----------------
|
310
|
+
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
311
|
+
# [num_height_tiles * h, num_width_tiles * w, D]
|
312
|
+
local_features = rearrange(
|
313
|
+
local_features,
|
314
|
+
"(th tw) (h w) d -> (th h) (tw w) d",
|
315
|
+
th=num_height_tiles,
|
316
|
+
tw=num_width_tiles,
|
317
|
+
h=h,
|
318
|
+
w=w,
|
351
319
|
)
|
352
320
|
|
353
|
-
|
321
|
+
# [D] -> [num_height_tiles * h, 1, D]
|
322
|
+
new_lines_in_local = repeat(
|
323
|
+
self.image_newline,
|
324
|
+
"d -> (th h) 1 d",
|
325
|
+
th=num_height_tiles,
|
326
|
+
h=h,
|
327
|
+
)
|
328
|
+
|
329
|
+
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
330
|
+
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
|
331
|
+
|
332
|
+
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
333
|
+
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
334
|
+
local_features = local_features.view(-1, n_dim)
|
335
|
+
|
336
|
+
# merge global and local tiles
|
337
|
+
if self.global_view_pos == "head":
|
338
|
+
global_local_features = torch.cat(
|
339
|
+
[
|
340
|
+
global_features,
|
341
|
+
self.view_seperator[None, :],
|
342
|
+
local_features,
|
343
|
+
]
|
344
|
+
)
|
345
|
+
else:
|
346
|
+
global_local_features = torch.cat(
|
347
|
+
[
|
348
|
+
local_features,
|
349
|
+
self.view_seperator[None, :],
|
350
|
+
global_features,
|
351
|
+
]
|
352
|
+
)
|
353
|
+
|
354
|
+
images_in_this_batch.append(global_local_features)
|
354
355
|
|
355
356
|
return torch.cat(images_in_this_batch, dim=0)
|
356
357
|
|
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
|
21
21
|
|
22
22
|
import torch
|
23
23
|
from torch import nn
|
24
|
-
from transformers import
|
25
|
-
AutoModel,
|
26
|
-
BatchFeature,
|
27
|
-
Gemma3Config,
|
28
|
-
Gemma3Processor,
|
29
|
-
PreTrainedModel,
|
30
|
-
)
|
31
|
-
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
24
|
+
from transformers import AutoModel, Gemma3Config, PreTrainedModel
|
32
25
|
|
33
26
|
from sglang.srt.hf_transformers_utils import get_processor
|
34
27
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
|
|
38
31
|
MultiModalityDataPaddingPatternTokenPairs,
|
39
32
|
general_mm_embed_routine,
|
40
33
|
)
|
41
|
-
from sglang.srt.managers.schedule_batch import
|
34
|
+
from sglang.srt.managers.schedule_batch import (
|
35
|
+
MultimodalDataItem,
|
36
|
+
MultimodalInputs,
|
37
|
+
flatten_nested_list,
|
38
|
+
)
|
42
39
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
43
40
|
from sglang.srt.model_loader.weight_utils import (
|
44
41
|
default_weight_loader,
|
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
274
271
|
"""
|
275
272
|
return self.language_model.get_attention_sliding_window_size()
|
276
273
|
|
277
|
-
def get_image_feature(self,
|
274
|
+
def get_image_feature(self, items: List[MultimodalDataItem]):
|
278
275
|
"""
|
279
276
|
Projects the last hidden state from the vision model into language model space.
|
280
277
|
|
281
|
-
Args:
|
282
|
-
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
283
|
-
The tensors corresponding to the input images.
|
284
278
|
Returns:
|
285
279
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
286
280
|
"""
|
287
|
-
pixel_values =
|
281
|
+
pixel_values = torch.stack(
|
282
|
+
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
283
|
+
)
|
288
284
|
pixel_values = pixel_values.to("cuda")
|
289
285
|
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
290
286
|
|
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
292
288
|
image_features = self.multi_modal_projector(vision_outputs)
|
293
289
|
return image_features
|
294
290
|
|
295
|
-
def embed_mm_inputs(
|
296
|
-
self,
|
297
|
-
input_ids: torch.Tensor,
|
298
|
-
forward_batch: ForwardBatch,
|
299
|
-
image_input: MultimodalInputs,
|
300
|
-
) -> torch.Tensor:
|
301
|
-
if input_ids is None:
|
302
|
-
raise ValueError("Unimplemented")
|
303
|
-
# boolean-masking image tokens
|
304
|
-
special_image_mask = torch.isin(
|
305
|
-
input_ids,
|
306
|
-
torch.tensor(image_input.pad_values, device=input_ids.device),
|
307
|
-
).unsqueeze(-1)
|
308
|
-
num_image_tokens_in_input_ids = special_image_mask.sum()
|
309
|
-
|
310
|
-
inputs_embeds = None
|
311
|
-
if num_image_tokens_in_input_ids == 0:
|
312
|
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
313
|
-
return inputs_embeds
|
314
|
-
else:
|
315
|
-
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
|
316
|
-
image_features = self.get_image_feature(image_input.pixel_values)
|
317
|
-
|
318
|
-
# print(f"image tokens from image embeddings: {image_features.numel()}")
|
319
|
-
num_image_tokens_in_embedding = (
|
320
|
-
image_features.shape[0] * image_features.shape[1]
|
321
|
-
)
|
322
|
-
|
323
|
-
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
324
|
-
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
|
325
|
-
image_features = image_features[:num_image, :]
|
326
|
-
logger.warning(
|
327
|
-
f"Number of images does not match number of special image tokens in the input text. "
|
328
|
-
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
329
|
-
"tokens from image embeddings."
|
330
|
-
)
|
331
|
-
|
332
|
-
# Important: clamp after extracting original image boundaries
|
333
|
-
input_ids.clamp_(min=0, max=self.vocab_size - 1)
|
334
|
-
|
335
|
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
336
|
-
|
337
|
-
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
338
|
-
inputs_embeds.device
|
339
|
-
)
|
340
|
-
|
341
|
-
image_features = image_features.to(
|
342
|
-
inputs_embeds.device, inputs_embeds.dtype
|
343
|
-
)
|
344
|
-
inputs_embeds = inputs_embeds.masked_scatter(
|
345
|
-
special_image_mask, image_features
|
346
|
-
)
|
347
|
-
|
348
|
-
return inputs_embeds
|
349
|
-
|
350
291
|
@torch.no_grad()
|
351
292
|
def forward(
|
352
293
|
self,
|
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
405
346
|
else:
|
406
347
|
llm_input_ids = input_ids
|
407
348
|
|
408
|
-
|
349
|
+
hs = general_mm_embed_routine(
|
409
350
|
input_ids=llm_input_ids,
|
410
351
|
forward_batch=forward_batch,
|
411
|
-
|
412
|
-
|
413
|
-
)
|
414
|
-
|
415
|
-
outputs = self.language_model(
|
416
|
-
input_ids=None,
|
352
|
+
language_model=self.language_model,
|
353
|
+
image_data_embedding_func=self.get_image_feature,
|
417
354
|
positions=positions,
|
418
|
-
forward_batch=forward_batch,
|
419
|
-
input_embeds=inputs_embeds,
|
420
|
-
**kwargs,
|
421
355
|
)
|
422
356
|
|
423
|
-
return
|
357
|
+
return hs
|
424
358
|
|
425
359
|
def tie_weights(self):
|
426
360
|
return self.language_model.tie_weights()
|
sglang/srt/models/llama.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
18
18
|
|
19
19
|
import logging
|
20
|
-
from typing import Any, Dict, Iterable, List, Optional,
|
20
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
21
21
|
|
22
22
|
import torch
|
23
23
|
from torch import nn
|
@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
|
|
428
428
|
else:
|
429
429
|
return self.pooler(hidden_states, forward_batch)
|
430
430
|
|
431
|
+
def get_input_embeddings(self) -> nn.Embedding:
|
432
|
+
return self.model.embed_tokens
|
433
|
+
|
431
434
|
def get_hidden_dim(self, module_name):
|
432
435
|
# return input_dim, output_dim
|
433
436
|
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
sglang/srt/models/llava.py
CHANGED
@@ -31,7 +31,7 @@ from transformers import (
|
|
31
31
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
32
32
|
|
33
33
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
34
|
-
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
34
|
+
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
|
35
35
|
from sglang.srt.mm_utils import (
|
36
36
|
get_anyres_image_grid_shape,
|
37
37
|
unpad_image,
|
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
42
42
|
from sglang.srt.models.llama import LlamaForCausalLM
|
43
43
|
from sglang.srt.models.mistral import MistralForCausalLM
|
44
44
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
45
|
-
from sglang.srt.utils import add_prefix
|
45
|
+
from sglang.srt.utils import add_prefix, flatten_nested_list
|
46
46
|
|
47
47
|
|
48
48
|
class LlavaBaseForCausalLM(nn.Module):
|
49
49
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
50
|
-
image_sizes
|
50
|
+
image_sizes = flatten_nested_list(
|
51
|
+
[item.image_sizes for item in image_inputs.mm_items]
|
52
|
+
)
|
53
|
+
|
54
|
+
pad_values = [item.pad_value for item in image_inputs.mm_items]
|
51
55
|
|
52
56
|
# hardcode for spatial_unpad + anyres
|
53
|
-
if
|
54
|
-
|
55
|
-
|
57
|
+
if any(
|
58
|
+
item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
|
59
|
+
for item in image_inputs.mm_items
|
56
60
|
):
|
57
61
|
image_aspect_ratio = "pad"
|
58
62
|
else:
|
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
66
70
|
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
67
71
|
)
|
68
72
|
else:
|
69
|
-
new_image_feature_len = self.image_feature_len #
|
73
|
+
new_image_feature_len = self.image_feature_len # multi-image
|
70
74
|
|
71
75
|
height = width = self.num_patches_per_side
|
72
76
|
if "anyres" in image_aspect_ratio:
|
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
101
105
|
# old_len + pad_len - 1, because we need to remove image_token_id
|
102
106
|
input_ids = (
|
103
107
|
input_ids[:offset]
|
104
|
-
+ [pad_values[image_idx]] * new_image_feature_len
|
108
|
+
+ [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
|
105
109
|
+ input_ids[offset + 1 :]
|
106
110
|
)
|
107
111
|
offset_list.append(offset)
|
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
150
154
|
modalities_list = []
|
151
155
|
max_image_offset = []
|
152
156
|
for im in image_inputs:
|
153
|
-
if im
|
154
|
-
modalities_list.extend(im.
|
157
|
+
if im:
|
158
|
+
modalities_list.extend([item.modality for item in im.mm_items])
|
155
159
|
if im and im.image_offsets:
|
156
160
|
max_image_offset.append(
|
157
161
|
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
|
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
164
168
|
|
165
169
|
if need_vision.any():
|
166
170
|
bs = forward_batch.batch_size
|
167
|
-
pixel_values =
|
168
|
-
|
169
|
-
|
171
|
+
pixel_values = flatten_nested_list(
|
172
|
+
[
|
173
|
+
[item.pixel_values for item in image_inputs[i].mm_items]
|
174
|
+
for i in range(bs)
|
175
|
+
if need_vision[i]
|
176
|
+
]
|
177
|
+
)
|
170
178
|
image_sizes = [
|
171
|
-
|
179
|
+
flatten_nested_list(
|
180
|
+
[item.image_sizes for item in image_inputs[i].mm_items]
|
181
|
+
)
|
182
|
+
for i in range(bs)
|
183
|
+
if need_vision[i]
|
172
184
|
]
|
173
185
|
|
174
186
|
########## Encode Image ########
|
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
197
209
|
new_image_features = []
|
198
210
|
height = width = self.num_patches_per_side
|
199
211
|
for image_idx, image_feature in enumerate(image_features):
|
200
|
-
if modalities_list[image_idx] ==
|
212
|
+
if modalities_list[image_idx] == Modality.IMAGE:
|
201
213
|
image_aspect_ratio = (
|
202
214
|
self.config.image_aspect_ratio
|
203
215
|
) # single image
|
204
216
|
elif (
|
205
|
-
modalities_list[image_idx] ==
|
206
|
-
or modalities_list[image_idx] ==
|
217
|
+
modalities_list[image_idx] == Modality.MULTI_IMAGES
|
218
|
+
or modalities_list[image_idx] == Modality.VIDEO
|
207
219
|
):
|
208
220
|
image_aspect_ratio = "pad" # multi image
|
209
221
|
# image_aspect_ratio = (
|
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
212
224
|
if (
|
213
225
|
image_feature.shape[0] > 1
|
214
226
|
and "anyres" in image_aspect_ratio
|
215
|
-
and modalities_list[image_idx] ==
|
227
|
+
and modalities_list[image_idx] == Modality.IMAGE
|
216
228
|
):
|
217
229
|
base_image_feature = image_feature[0]
|
218
230
|
image_feature = image_feature[1:]
|
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
312
324
|
)
|
313
325
|
image_feature = image_feature.unsqueeze(0)
|
314
326
|
else:
|
315
|
-
if modalities_list[image_idx] ==
|
327
|
+
if modalities_list[image_idx] == Modality.VIDEO: # video
|
316
328
|
# 2x2 pooling
|
317
329
|
num_of_frames = image_feature.shape[0]
|
318
330
|
image_feature = image_feature.view(
|
sglang/srt/models/llavavid.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
|
|
22
22
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
23
23
|
|
24
24
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
25
|
-
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
25
|
+
from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
|
26
26
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
27
27
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
28
28
|
from sglang.srt.models.llama import LlamaForCausalLM
|
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|
58
58
|
)
|
59
59
|
|
60
60
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
61
|
-
pad_values = image_inputs.
|
61
|
+
pad_values = [item.pad_value for item in image_inputs.mm_items]
|
62
62
|
new_image_feature_len = self.image_feature_len
|
63
63
|
|
64
64
|
pad_ids = pad_values * (
|
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
|
|
133
133
|
need_vision = start_positions <= np.array(max_image_offset)
|
134
134
|
|
135
135
|
if need_vision.any():
|
136
|
-
pixel_values =
|
137
|
-
|
138
|
-
|
136
|
+
pixel_values = flatten_nested_list(
|
137
|
+
[
|
138
|
+
[item.pixel_values for item in image_inputs[i].mm_items]
|
139
|
+
for i in range(bs)
|
140
|
+
if need_vision[i]
|
141
|
+
]
|
142
|
+
)
|
139
143
|
image_offsets = [
|
140
|
-
|
144
|
+
flatten_nested_list(
|
145
|
+
[item.image_offsets for item in image_inputs[i].mm_items]
|
146
|
+
)
|
147
|
+
for i in range(bs)
|
148
|
+
if need_vision[i]
|
141
149
|
]
|
142
150
|
|
143
151
|
########## Encode Image ########
|
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
|
|
246
254
|
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
247
255
|
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
|
248
256
|
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
|
249
|
-
"model.vision_tower.vision_tower": "vision_tower",
|
257
|
+
"model.vision_tower.vision_tower": "vision_tower",
|
258
|
+
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
250
259
|
"model.image_newline": "language_model.model.image_newline",
|
251
260
|
}
|
252
261
|
params_dict = dict(self.named_parameters())
|