sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_vl2.py

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x
 
 
-# todo
 class DeepseekVL2ForCausalLM(nn.Module):
 
     def __init__(
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=input_embeds,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )
 
-        return outputs
+        return hs
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
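The block removed above inlined token embedding and image-feature scattering directly in the model; the new code delegates that work to the shared general_mm_embed_routine helper, whose full signature is not visible in this diff. As a rough illustration of the embed-then-scatter pattern being centralized (the function and variable names below are placeholders for this example, not sglang APIs):

    import torch

    def embed_with_image_features(
        input_ids: torch.Tensor,              # [seq_len]; image slots hold image_token_id
        token_embedding: torch.nn.Embedding,
        image_features: torch.Tensor,         # [num_image_tokens, hidden]
        image_token_id: int,
    ) -> torch.Tensor:
        # Start from ordinary token embeddings.
        inputs_embeds = token_embedding(input_ids)
        # Boolean mask of the positions reserved for image tokens.
        mask = (input_ids == image_token_id).unsqueeze(-1)
        # Overwrite exactly those positions with the projected image features,
        # mirroring the masked_scatter call in the removed code.
        return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))

    # Toy usage: two image slots inside a six-token prompt.
    emb = torch.nn.Embedding(32, 8)
    ids = torch.tensor([1, 2, 9, 9, 3, 4])  # 9 stands in for the image token id
    print(embed_with_image_features(ids, emb, torch.randn(2, 8), image_token_id=9).shape)
    # torch.Size([6, 8])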
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
                 weights_loader(param, loaded_weight)
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return input_ids
-
-    def get_image_feature(self, image_input: MultimodalInputs):
-        pixel_values = image_input.pixel_values.type(
-            next(self.vision.parameters()).dtype
-        ).to(device=next(self.vision.parameters()).device)
-        image_feature = self.vision.forward_features(pixel_values)
-        images_embeds = self.projector(image_feature)
-        _, hw, n_dim = images_embeds.shape
-        h = w = int(hw**0.5)
-        tile_index = 0
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
         images_in_this_batch = []
-        images_spatial_crop = image_input.image_spatial_crop
-        for jdx in range(images_spatial_crop.shape[1]):
-            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
-            if num_width_tiles == 0 or num_height_tiles == 0:
-                break
-            num_tiles_in_image = num_width_tiles * num_height_tiles
-
-            # [hw, D]
-            global_features = images_embeds[tile_index]
-
-            # [num_height_tiles * num_width_tiles, hw, D]
-            local_features = images_embeds[
-                tile_index + 1 : tile_index + 1 + num_tiles_in_image
-            ]
-            tile_index += num_tiles_in_image + 1
-
-            # format global and local features
-            # ----------------- global view add newline -----------------
-            # [hw, D] -> [h, w, D]
-            global_features = global_features.view(h, w, n_dim)
-
-            # [D] -> [h, 1, D]
-            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
-
-            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
-
-            # [h, w + 1, D] -> [h * (w + 1), D]
-            global_features = global_features.view(-1, n_dim)
-
-            # ----------------- local view add newline -----------------
-            # [num_height_tiles * num_width_tiles, h * w, D] ->
-            # [num_height_tiles * h, num_width_tiles * w, D]
-            local_features = rearrange(
-                local_features,
-                "(th tw) (h w) d -> (th h) (tw w) d",
-                th=num_height_tiles,
-                tw=num_width_tiles,
-                h=h,
-                w=w,
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
             )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw**0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[
+                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
+                ]
+                tile_index += num_tiles_in_image + 1
 
-            # [D] -> [num_height_tiles * h, 1, D]
-            new_lines_in_local = repeat(
-                self.image_newline,
-                "d -> (th h) 1 d",
-                th=num_height_tiles,
-                h=h,
-            )
+                # format global and local features
+                # ----------------- global view add newline -----------------
+                # [hw, D] -> [h, w, D]
+                global_features = global_features.view(h, w, n_dim)
+
+                # [D] -> [h, 1, D]
+                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
 
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            local_features = torch.cat([local_features, new_lines_in_local], dim=1)
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
-            local_features = local_features.view(-1, n_dim)
-
-            # merge global and local tiles
-            if self.global_view_pos == "head":
-                global_local_features = torch.cat(
-                    [
-                        global_features,
-                        self.view_seperator[None, :],
-                        local_features,
-                    ]
+                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                global_features = torch.cat(
+                    [global_features, new_lines_in_global], dim=1
                 )
-            else:
-                global_local_features = torch.cat(
-                    [
-                        local_features,
-                        self.view_seperator[None, :],
-                        global_features,
-                    ]
+
+                # [h, w + 1, D] -> [h * (w + 1), D]
+                global_features = global_features.view(-1, n_dim)
+
+                # ----------------- local view add newline -----------------
+                # [num_height_tiles * num_width_tiles, h * w, D] ->
+                # [num_height_tiles * h, num_width_tiles * w, D]
+                local_features = rearrange(
+                    local_features,
+                    "(th tw) (h w) d -> (th h) (tw w) d",
+                    th=num_height_tiles,
+                    tw=num_width_tiles,
+                    h=h,
+                    w=w,
                 )
 
-            images_in_this_batch.append(global_local_features)
+                # [D] -> [num_height_tiles * h, 1, D]
+                new_lines_in_local = repeat(
+                    self.image_newline,
+                    "d -> (th h) 1 d",
+                    th=num_height_tiles,
+                    h=h,
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                local_features = torch.cat([local_features, new_lines_in_local], dim=1)
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                local_features = local_features.view(-1, n_dim)
+
+                # merge global and local tiles
+                if self.global_view_pos == "head":
+                    global_local_features = torch.cat(
+                        [
+                            global_features,
+                            self.view_seperator[None, :],
+                            local_features,
+                        ]
+                    )
+                else:
+                    global_local_features = torch.cat(
+                        [
+                            local_features,
+                            self.view_seperator[None, :],
+                            global_features,
+                        ]
+                    )
+
+                images_in_this_batch.append(global_local_features)
 
         return torch.cat(images_in_this_batch, dim=0)
 
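The rewritten get_image_feature keeps the same global-plus-local tile layout but now iterates over MultimodalDataItem objects. Assuming the projector emits hw = h * w tokens per tile with h == w (as the int(hw**0.5) line implies), the per-image token count can be sanity-checked with a small arithmetic sketch (illustrative only, no sglang code involved):

    # Worked shape check for the tile layout above (pure arithmetic, no model code).
    def deepseek_vl2_image_token_count(hw: int, num_width_tiles: int, num_height_tiles: int) -> int:
        h = w = int(hw**0.5)
        global_tokens = h * (w + 1)  # global view plus one newline token per row
        local_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)  # tiled local view plus newlines
        separator = 1  # the view_seperator embedding between the two views
        return global_tokens + separator + local_tokens

    # e.g. a 24x24 token grid per tile and a 2x2 crop:
    print(deepseek_vl2_image_token_count(576, 2, 2))  # 600 + 1 + 2352 = 2953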
sglang/srt/models/gemma3_mm.py

@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import (
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         """
         return self.language_model.get_attention_sliding_window_size()
 
-    def get_image_feature(self, image_input: MultimodalInputs):
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.
 
-        Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = image_input.pixel_values
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-
-            return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids
 
-        inputs_embeds = general_mm_embed_routine(
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            **kwargs,
         )
 
-        return outputs
+        return hs
 
     def tie_weights(self):
         return self.language_model.tie_weights()
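In the gemma3 hunks above, get_image_feature now collects item.pixel_values from each MultimodalDataItem and stacks them into one batch; flatten_nested_list is applied first because an item may carry either a single tensor or a nested list of tensors. A minimal stand-in with the behaviour this call site assumes (not necessarily sglang's exact implementation):

    # Hypothetical stand-in for flatten_nested_list: recursively flatten nested
    # lists into one flat list, leaving non-list elements (e.g. tensors) intact.
    def flatten_nested_list(nested):
        if isinstance(nested, list):
            return [leaf for item in nested for leaf in flatten_nested_list(item)]
        return [nested]

    # Mirrors the call shape in gemma3_mm.py:
    #   torch.stack(flatten_nested_list([item.pixel_values for item in items]), dim=0)
    print(flatten_nested_list([[1, 2], [3, [4, 5]], 6]))  # [1, 2, 3, 4, 5, 6]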
sglang/srt/models/llama.py

@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
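The new LlamaForCausalLM.get_input_embeddings accessor gives callers, such as a shared multimodal embedding routine, a uniform way to reach the token embedding table without touching model-specific attributes. A hedged sketch of how such a caller might use it; the helper and toy model below are illustrative, not sglang code:

    import torch

    def embed_tokens_generic(language_model, input_ids: torch.Tensor) -> torch.Tensor:
        # Works for any model exposing get_input_embeddings() -> nn.Embedding.
        return language_model.get_input_embeddings()(input_ids)

    class _ToyLM:
        def __init__(self):
            self._emb = torch.nn.Embedding(10, 4)

        def get_input_embeddings(self):
            return self._emb

    print(embed_tokens_generic(_ToyLM(), torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])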
sglang/srt/models/llava.py

@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list
 
 
 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
 
         # hardcode for spatial_unpad + anyres
-        if image_inputs.modalities is not None and (
-            "multi-images" in image_inputs.modalities
-            or "video" in image_inputs.modalities
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:
@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  # multiimage
+            new_image_feature_len = self.image_feature_len  # multi-image
 
         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:
@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + [pad_values[image_idx]] * new_image_feature_len
+                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
         for im in image_inputs:
-            if im and im.modalities is not None:
-                modalities_list.extend(im.modalities)
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
 
         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_sizes = [
-                image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########
@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
             new_image_features = []
             height = width = self.num_patches_per_side
             for image_idx, image_feature in enumerate(image_features):
-                if modalities_list[image_idx] == "image":
+                if modalities_list[image_idx] == Modality.IMAGE:
                     image_aspect_ratio = (
                         self.config.image_aspect_ratio
                     )  # single image
                 elif (
-                    modalities_list[image_idx] == "multi-images"
-                    or modalities_list[image_idx] == "video"
+                    modalities_list[image_idx] == Modality.MULTI_IMAGES
+                    or modalities_list[image_idx] == Modality.VIDEO
                 ):
                     image_aspect_ratio = "pad"  # multi image
                     # image_aspect_ratio = (
@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 if (
                     image_feature.shape[0] > 1
                     and "anyres" in image_aspect_ratio
-                    and modalities_list[image_idx] == "image"
+                    and modalities_list[image_idx] == Modality.IMAGE
                 ):
                     base_image_feature = image_feature[0]
                     image_feature = image_feature[1:]
@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     )
                     image_feature = image_feature.unsqueeze(0)
                 else:
-                    if modalities_list[image_idx] == "video":  # video
+                    if modalities_list[image_idx] == Modality.VIDEO:  # video
                         # 2x2 pooling
                         num_of_frames = image_feature.shape[0]
                         image_feature = image_feature.view(
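The llava hunks above replace free-form modality strings ("image", "multi-images", "video") with a Modality enum carried on each MultimodalDataItem, so the branch logic compares enum members instead of strings. A hypothetical sketch of the enum shape these comparisons assume; the real definition lives in sglang/srt/managers/schedule_batch.py and may differ:

    from enum import Enum, auto

    class Modality(Enum):  # illustrative stand-in, not the sglang definition
        IMAGE = auto()
        MULTI_IMAGES = auto()
        VIDEO = auto()

    def aspect_ratio_for(modality: Modality, configured: str) -> str:
        # Mirrors the branch in LlavaBaseForCausalLM.forward: multi-image and
        # video inputs fall back to "pad"; single images keep the config value.
        if modality in (Modality.MULTI_IMAGES, Modality.VIDEO):
            return "pad"
        return configured

    print(aspect_ratio_for(Modality.VIDEO, "anyres"))  # pad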
sglang/srt/models/llavavid.py

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
         new_image_feature_len = self.image_feature_len
 
         pad_ids = pad_values * (
@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)
 
         if need_vision.any():
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_offsets = [
-                image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########
@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
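In llavavid.py, pad_input_ids now derives one pad value per MultimodalDataItem ([item.pad_value for item in image_inputs.mm_items]) and repeats that list to fill the image placeholder region. A toy illustration of the gather-and-repeat pattern; the dataclass and the exact repeat/truncate arithmetic below are simplified stand-ins, not sglang's real code:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class ToyItem:  # stand-in for MultimodalDataItem
        pad_value: int

    def build_pad_ids(mm_items: List[ToyItem], new_image_feature_len: int) -> List[int]:
        pad_values = [item.pad_value for item in mm_items]
        # Repeat the per-item values until at least new_image_feature_len ids
        # are available, then truncate to the required length.
        repeats = -(-new_image_feature_len // len(pad_values))  # ceiling division
        return (pad_values * repeats)[:new_image_feature_len]

    print(build_pad_ids([ToyItem(7), ToyItem(9)], 5))  # [7, 9, 7, 9, 7]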