sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -27,6 +27,7 @@ import huggingface_hub.constants
  import numpy as np
  import torch
  from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+ from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
  from safetensors.torch import load_file, safe_open, save_file
  from tqdm.auto import tqdm
 
@@ -403,8 +404,13 @@ def np_cache_weights_iterator(
 
  def safetensors_weights_iterator(
      hf_weights_files: List[str],
+     is_all_weights_sharded: bool = False,
  ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-     """Iterate over the weights in the model safetensor files."""
+     """Iterate over the weights in the model safetensor files.
+
+     If is_all_weights_sharded is True, it uses a more optimized read by reading an
+     entire file instead of reading each tensor one by one.
+     """
      enable_tqdm = (
          not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
      )
@@ -414,9 +420,14 @@ def safetensors_weights_iterator(
          disable=not enable_tqdm,
          bar_format=_BAR_FORMAT,
      ):
-         with safe_open(st_file, framework="pt") as f:
-             for name in f.keys():  # noqa: SIM118
-                 param = f.get_tensor(name)
+         if not is_all_weights_sharded:
+             with safe_open(st_file, framework="pt") as f:
+                 for name in f.keys():  # noqa: SIM118
+                     param = f.get_tensor(name)
+                     yield name, param
+         else:
+             result = load_file(st_file, device="cpu")
+             for name, param in result.items():
                  yield name, param
 
 
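For orientation, a minimal standalone sketch of the two read paths this change switches between (the shard filename below is an assumption, not taken from the diff):

from safetensors.torch import load_file, safe_open

st_file = "model-00001-of-00002.safetensors"  # hypothetical shard name

# Default path: open the file once and lazily fetch each tensor by name.
with safe_open(st_file, framework="pt") as f:
    for name in f.keys():
        param = f.get_tensor(name)
        print(name, tuple(param.shape))

# is_all_weights_sharded=True path: read the entire shard into CPU memory in one call.
for name, param in load_file(st_file, device="cpu").items():
    print(name, tuple(param.shape))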
@@ -650,6 +661,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
      return name
 
 
+ # Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+ class KVCacheQuantSchema(BaseModel):
+     dtype: str
+     # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+     # layer indices to their per-tensor KV cache scaling factor.
+     # TODO: Consider pulling this and its validation methods out into its
+     # own schema class (tricky as its members are variable)
+     scaling_factor: Dict[int, Dict[int, float]]
+
+     @model_validator(mode="after")
+     def check_is_fp8(self) -> "KVCacheQuantSchema":
+         assert self.dtype == "float8_e4m3fn", (
+             "Loaded scaling factors intended for KV cache dtype = "
+             f"{self.dtype} rather than float8_e4m3fn!"
+         )
+         return self
+
+     @model_validator(mode="after")
+     def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+         context = info.context
+         if context:
+             tp_size = context["tp_size"]
+             num_hidden_layers = context["num_hidden_layers"]
+             assert len(self.scaling_factor) == tp_size, (
+                 f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                 f"but LLM engine is currently running with TP size {tp_size}."
+             )
+             for tp_rank, layer_maps in self.scaling_factor.items():
+                 assert len(layer_maps) == num_hidden_layers, (
+                     f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                     f"Expected {num_hidden_layers} layers, got "
+                     f"{len(layer_maps)}."
+                 )
+             for i in range(tp_size):
+                 assert (
+                     i in self.scaling_factor
+                 ), f"KV cache scales map for TP rank {i} not found."
+         return self
+
+     @model_validator(mode="after")
+     def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+         context = info.context
+         if context:
+             tp_rank = context["tp_rank"]
+             num_hidden_layers = context["num_hidden_layers"]
+             layer_scales_map = self.scaling_factor[tp_rank]
+             for i in range(num_hidden_layers):
+                 assert i in layer_scales_map, (
+                     f"Could not find KV cache scales for layer {i} in "
+                     f"TP rank {tp_rank}."
+                 )
+         return self
+
+
+ class QuantParamSchema(BaseModel):
+     # TODO: Generalize and extend with more fields
+     # (e.g. weights/activations params) once functionality is enabled
+     model_config = ConfigDict(protected_namespaces=())
+     model_type: Optional[str]
+     kv_cache: KVCacheQuantSchema
+
+     @model_validator(mode="after")
+     def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+         context = info.context
+         if context:
+             model_type = context.get("model_type", None)
+             if model_type is not None:
+                 assert model_type == self.model_type, (
+                     f"Model type is {model_type} but loaded "
+                     f"scaling factors belonging to different "
+                     f"model type {self.model_type}!"
+                 )
+         return self
+
+
  def kv_cache_scales_loader(
      filename: str,
      tp_rank: int,
@@ -681,7 +767,7 @@ def kv_cache_scales_loader(
      except json.JSONDecodeError:
          logger.error("Error decoding JSON in file '%s'.", filename)
      except Exception:
-         logger.exception("An error occurred while reading '%s'.", filename)
+         logger.error("An error occurred while reading '%s'.", filename)
      # This section is reached if and only if any of the excepts are hit
      # Return an empty iterable (list) => no KV cache scales are loaded
      # which ultimately defaults to 1.0 scales
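A rough, hypothetical sketch of how the schemas added above could be exercised directly (the payload and context values are invented; in practice kv_cache_scales_loader reads them from the JSON file it is given):

from sglang.srt.model_loader.weight_utils import QuantParamSchema

# Invented payload for illustration; real scaling factors are produced by a quantization tool.
payload = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor KV cache scale}
        "scaling_factor": {0: {0: 0.5, 1: 0.75}},
    },
}

schema = QuantParamSchema.model_validate(
    payload,
    context={
        "model_type": "llama",
        "tp_size": 1,
        "tp_rank": 0,
        "num_hidden_layers": 2,
    },
)
print(schema.kv_cache.scaling_factor[0])  # {0: 0.5, 1: 0.75}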
sglang/srt/models/commandr.py CHANGED
@@ -61,7 +61,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+ )
  from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
@@ -372,10 +375,19 @@ class CohereForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 # Remapping the name of FP8 kv-scale.
+                 name = maybe_remap_kv_scale_name(name, params_dict)
+                 if name is None:
+                     continue
+
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
 
- EntryClass = CohereForCausalLM
+ class Cohere2ForCausalLM(CohereForCausalLM):
+     pass
+
+
+ EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
sglang/srt/models/dbrx.py CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+ )
  from sglang.srt.utils import set_weight_attrs
 
 
@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
                  weight_loader(param, loaded_weight, weight_name)
                  break
              else:
+                 # Remapping the name of FP8 kv-scale.
+                 name = maybe_remap_kv_scale_name(name, params_dict)
+                 if name is None:
+                     continue
+
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
sglang/srt/models/deepseek_v2.py CHANGED
@@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import is_flashinfer_available, is_hip
+ from sglang.srt.utils import is_cuda_available, is_hip
 
  is_hip_ = is_hip()
 
- if is_flashinfer_available():
-     from flashinfer import bmm_fp8
+ if is_cuda_available():
+     from sgl_kernel import bmm_fp8
 
 
  class DeepseekV2MLP(nn.Module):
sglang/srt/models/gemma2.py CHANGED
@@ -35,7 +35,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+ )
  from sglang.srt.utils import make_layers
 
 
@@ -424,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 # Remapping the name of FP8 kv-scale.
+                 name = maybe_remap_kv_scale_name(name, params_dict)
+                 if name is None:
+                     continue
+
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
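The same guard now appears in several load_weights implementations (commandr, dbrx, and gemma2 above). Schematically the pattern is as follows; load_one_weight is a hypothetical helper name used only for illustration:

from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)


def load_one_weight(name, loaded_weight, params_dict):
    # Skip loading extra bias for GPTQ models.
    if name.endswith(".bias") and name not in params_dict:
        return
    # Remap checkpoint kv-scale entries onto the model's FP8 kv-scale parameters;
    # None means there is no matching parameter and the entry is skipped.
    name = maybe_remap_kv_scale_name(name, params_dict)
    if name is None:
        return
    param = params_dict[name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)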
sglang/srt/models/grok.py CHANGED
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
              renormalize=False,
              quant_config=quant_config,
              tp_size=tp_size,
+             activation="gelu",
              use_presharded_weights=use_presharded_weights,
          )
 
sglang/srt/models/minicpm3.py CHANGED
@@ -40,10 +40,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_cuda_available
 
- if is_flashinfer_available():
-     from flashinfer import bmm_fp8
+ if is_cuda_available():
+     from sgl_kernel import bmm_fp8
 
 
  class MiniCPM3MLP(nn.Module):
sglang/srt/models/minicpmv.py CHANGED
@@ -1,6 +1,6 @@
  # Adapted from
  # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
- # Copyright 2023 The vLLM team.
+ # Copyright 2023 The SGLang team.
  # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  #
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -20,7 +20,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
- from functools import cached_property, partial
+ from functools import partial
  from typing import (
      Any,
      Callable,
@@ -33,16 +33,13 @@ from typing import (
      Union,
  )
 
+ import numpy as np
  import torch
  import torch.types
  from PIL import Image
  from torch import nn
  from torch.nn.init import trunc_normal_
  from transformers import PretrainedConfig
- from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
- from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
- from vllm.model_executor.models.module_mapping import MultiModelKeys
- from vllm.model_executor.sampling_metadata import SamplingMetadata
 
  from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import get_act_fn
@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
  RawImageType = Union[Image.Image, torch.Tensor]
 
 
+ # sin/cos positional embedding helpers are adapted from:
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+ def get_1d_sincos_pos_embed_from_grid(
+     embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
+ ) -> torch.Tensor:
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,) / (H, W)
+     out: (M, D) / (H, W, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     if version == (2, 0):
+         pos = pos.reshape(-1)  # (M,)
+         out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+         emb_sin = np.sin(out)  # (M, D/2)
+         emb_cos = np.cos(out)  # (M, D/2)
+         emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     else:
+         out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
+         emb_sin = np.sin(out)  # (H, W, D/2)
+         emb_cos = np.cos(out)  # (H, W, D/2)
+         emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed_from_grid(
+     embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
+ ) -> torch.Tensor:
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(
+         embed_dim // 2, grid[0], version
+     )  # (H*W, D/2) or (H, W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(
+         embed_dim // 2, grid[1], version
+     )  # (H*W, D/2) or (H, W, D/2)
+
+     if version == (2, 0):
+         emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     else:
+         emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed(
+     embed_dim: int,
+     grid_size: Union[int, Tuple[int, int]],
+     cls_token: bool = False,
+     version: Tuple[int, int] = (2, 0),
+ ) -> torch.Tensor:
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or
+                [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     if isinstance(grid_size, int):
+         grid_h_size, grid_w_size = grid_size, grid_size
+     else:
+         grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+     grid_h = np.arange(grid_h_size, dtype=np.float32)
+     grid_w = np.arange(grid_w_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+     assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)
+
+     if version == (2, 0):
+         grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+         pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+         if cls_token:
+             pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     else:
+         pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+     return pos_embed
+
+
  class Idefics2VisionMLP(nn.Module):
 
      def __init__(
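A quick shape check of the helpers defined above (the 14x14 grid and 256-dim embedding are arbitrary choices for illustration):

# The helpers return NumPy arrays despite the torch.Tensor annotation.
# (2, 0)-style layout flattens positions to (H*W, D).
pos_embed = get_2d_sincos_pos_embed(256, (14, 14), version=(2, 0))
assert pos_embed.shape == (14 * 14, 256)

# Any other version keeps the (H, W, D) layout.
pos_embed_hw = get_2d_sincos_pos_embed(256, (14, 14), version=(2, 5))
assert pos_embed_hw.shape == (14, 14, 256)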
@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module):
              projection_size=config.intermediate_size,
              use_qkv_parallel=True,
              quant_config=quant_config,
+             dropout=config.attention_dropout,
+             use_context_forward=False,
+             use_full_precision_softmax=True,
+             flatten_batch=False,
              prefix=f"{prefix}.self_attn",
          )
          self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module):
          self,
          hidden_states: torch.Tensor,
          cu_seqlens: torch.Tensor,
-         forward_batch: ForwardBatch,
      ) -> torch.Tensor:
          """
          Args:
@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module):
          """
          residual = hidden_states
          hidden_states = self.layer_norm1(hidden_states)
-         hidden_states = self.self_attn(
-             hidden_states,
-             cu_seqlens=cu_seqlens,
-             # , forward_batch=forward_batch
-         )
+         hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)
+
          hidden_states = residual + hidden_states
          residual = hidden_states
          hidden_states = self.layer_norm2(hidden_states)
@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module):
          self,
          inputs_embeds: torch.Tensor,
          cu_seqlens: torch.Tensor,
-         forward_batch: ForwardBatch,
      ) -> torch.Tensor:
          r"""
          Args:
@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module):
          hidden_states = inputs_embeds
          for encoder_layer in self.layers:
              layer_outputs = encoder_layer(
-                 hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
              )
              hidden_states = layer_outputs
          return hidden_states
@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module):
          self.num_positions = self.num_patches
          self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
-     def forward(
+     def get_position_ids(
          self,
          pixel_values: torch.FloatTensor,
          patch_attention_mask: torch.BoolTensor,
          tgt_sizes: Optional[torch.IntTensor] = None,
-     ) -> torch.Tensor:
+     ):
          batch_size, _, max_im_h, max_im_w = pixel_values.shape
-         target_dtype = self.patch_embedding.weight.dtype
-         pixel_values = pixel_values.to(
-             device=self.patch_embedding.weight.device, dtype=target_dtype
-         )
-         patch_embeds = self.patch_embedding(pixel_values)
-         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
          max_nb_patches_h, max_nb_patches_w = (
              max_im_h // self.patch_size,
              max_im_w // self.patch_size,
@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module):
              ).flatten()
              position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
          position_ids = position_ids.to(self.position_embedding.weight.device)
+         return position_ids
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         patch_attention_mask: torch.BoolTensor,
+         tgt_sizes: Optional[torch.IntTensor] = None,
+     ) -> torch.Tensor:
+         target_dtype = self.patch_embedding.weight.dtype
+         pixel_values = pixel_values.to(
+             device=self.patch_embedding.weight.device, dtype=target_dtype
+         )
+         patch_embeds = self.patch_embedding(pixel_values)
+         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         position_ids = self.get_position_ids(
+             pixel_values, patch_attention_mask, tgt_sizes
+         )
+
          embeddings = embeddings + self.position_embedding(position_ids)
          return embeddings
 
@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         prefix: str = "",
      ) -> None:
          super().__init__()
 
@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module):
 
      def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
          patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
-
-         # do a prefix sum to get cu_seqlens; note that a 0 is inserted at the front as the offset
          cu_seqlens = torch.cat(
              [
                  torch.tensor([0], device=patch_len.device, dtype=torch.int32),
@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module):
      def forward(
          self,
          pixel_values,
-         forward_batch: ForwardBatch,
          patch_attention_mask: Optional[torch.BoolTensor] = None,
          tgt_sizes: Optional[torch.IntTensor] = None,
      ) -> torch.Tensor:
          hidden_states = self.embeddings(
              pixel_values=pixel_values,
              patch_attention_mask=patch_attention_mask,
-             # forward_batch=forward_batch,
              tgt_sizes=tgt_sizes,
          )
          cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
          encoder_outputs = self.encoder(
-             hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+             hidden_states,
+             cu_seqlens=cu_seqlens,
          )
          last_hidden_state = self.post_layernorm(encoder_outputs)
          return last_hidden_state
@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module):
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
      ):
-         # multimodal_config = config.model_config.multimodal_config
          super().__init__()
          # All MiniCPM-V models disable `tie_word_embeddings` but
          # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
-         # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
+         # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model
          # and config class
          self.config = config
-         # self.multimodal_config = multimodal_config
 
          self.version = get_version_by_config(self.config)
          self.llm = self.init_llm(config=config, quant_config=quant_config)
@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module):
 
          self.logits_processor = LogitsProcessor(config)
 
-     @cached_property
-     def sampler(self):
-         if hasattr(self.llm, "sampler"):
-             return self.llm.sampler
-
-         return get_sampler()
-
      def _get_image_bounds(
          self,
          input_ids: torch.Tensor,
@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module):
          self,
          input_ids: torch.Tensor,
          image_inputs: Optional[MiniCPMVImageInputs],
-         forward_batch: ForwardBatch,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
 
@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module):
                  .to(vlm_embedding.device)
              )
          else:
-             vision_hidden_states = self.get_vision_hidden_states(
-                 forward_batch, image_inputs
-             )
-
+             vision_hidden_states = self.get_vision_hidden_states(image_inputs)
              # See NOTE in _parse_and_validate_inputs
              image_bounds = image_inputs["image_bounds"]
              if len(image_bounds) > 0:
@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module):
                          for start, end in image_bounds.tolist()
                      ]
                  ).to(vlm_embedding.device)
+
                  vlm_embedding.scatter_(
                      0,
                      image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module):
          # There values are useless because their embeddings will be replaced by vision embeddings anyway.
          input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
 
-         vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
+         vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
 
          # always pass the input via `inputs_embeds`
          # to make sure the computation graph is consistent
@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module):
              input_ids, hidden_states, self.llm.lm_head, forward_batch
          )
 
-     def compute_logits(
-         self,
-         hidden_states: torch.Tensor,
-         sampling_metadata: SamplingMetadata,
-     ) -> Optional[torch.Tensor]:
-         return self.llm.compute_logits(hidden_states, sampling_metadata)
-
-     def sample(
-         self,
-         logits: torch.Tensor,
-         sampling_metadata: SamplingMetadata,
-     ) -> Optional[SamplerOutput]:
-         next_tokens = self.sampler(logits, sampling_metadata)
-         return next_tokens
-
-     def get_mm_mapping(self) -> MultiModelKeys:
-         """
-         Get the module prefix in multimodal models
-         """
-         return MultiModelKeys.from_string_field(
-             language_model="llm", connector="resampler", tower_model="vpm"
-         )
-
      def init_llm(
          self,
          config: Qwen2Config,
@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module):
      ) -> torch.Tensor:
          raise NotImplementedError
 
-     def get_vision_hidden_states(
-         self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
-     ) -> torch.Tensor:
+     def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
          raise NotImplementedError
 
 
@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
 
      def get_vision_hidden_states(
          self,
-         forward_batch: ForwardBatch,
          data: MiniCPMVImageInputs,
      ) -> torch.Tensor:
          pixel_values = data["data"]
@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
          patch_attn_mask = torch.zeros(
              (B, 1, max_patches), dtype=torch.bool, device=device
          )
-         for i in range(B):
-             patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+         tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+         mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+         patch_attn_mask[:, 0, :] = torch.arange(
+             patch_attn_mask.size(2), device=patch_attn_mask.device
+         ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
          vision_embedding = self.vpm(
              all_pixel_values.type(dtype),
-             forward_batch=forward_batch,
              patch_attention_mask=patch_attn_mask,
              tgt_sizes=tgt_sizes,
          )
-
          return self.resampler(vision_embedding, tgt_sizes)
 
      def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
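The batched mask construction above replaces the per-image Python loop with one broadcast comparison. A small standalone check of the equivalence (the patch-grid sizes are invented):

import torch

# Invented example: three images whose (h, w) patch grids flatten to 6, 12 and 4 patches.
tgt_sizes = torch.tensor([[2, 3], [3, 4], [2, 2]])
B, max_patches = tgt_sizes.size(0), 12

# Old formulation: loop over the batch.
mask_loop = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    mask_loop[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

# New formulation: compare arange(max_patches) against the per-image patch counts.
mask_vec = torch.zeros((B, 1, max_patches), dtype=torch.bool)
mask_shapes = tgt_sizes[:, 0] * tgt_sizes[:, 1]
mask_vec[:, 0, :] = torch.arange(max_patches).unsqueeze(0) < mask_shapes.unsqueeze(1)

assert torch.equal(mask_loop, mask_vec)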
@@ -1138,7 +1191,7 @@ class MiniCPMV:
      """
      Different versions of MiniCPMV use different visual encoders and LLMs,
      which is not conducive to the current integration logic of LoRA and
-     bitsandbytes in vLLM. Therefore, it is necessary to separate them.
+     bitsandbytes in SGLang. Therefore, it is necessary to separate them.
      """
 
      # Ensure that the LoRA support check passes when the class is not