sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,358 @@
|
|
1
|
+
from typing import Iterable, List, Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
from einops import rearrange, repeat
|
6
|
+
from torch import nn
|
7
|
+
|
8
|
+
from sglang.srt.configs.deepseekvl2 import (
|
9
|
+
DeepseekVL2Config,
|
10
|
+
DeepseekVL2MlpProjectorConfig,
|
11
|
+
)
|
12
|
+
from sglang.srt.layers.linear import ReplicatedLinear
|
13
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
14
|
+
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
15
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
16
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
17
|
+
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
18
|
+
|
19
|
+
|
20
|
+
class DeepseekVL2MlpProjector(nn.Module):
    """Project vision features to the width ``config.n_embed``.

    The layer stack stored in ``self.layers`` is chosen by
    ``config.projector_type``:

    * ``"identity"``: a single pass-through layer.
    * ``"linear"``: one linear layer ``input_dim -> n_embed``.
    * ``"mlp_gelu"``: ``config.depth`` linear layers separated by GELU.
    * ``"downsample_mlp_gelu"``: spatial ``downsample_ratio x downsample_ratio``
      patches are concatenated channel-wise in ``forward`` and then fed to an
      MLP whose hidden width is ``n_embed * mlp_ratio``.

    When ``config.token_pooling`` is set, an additional linear layer
    (``input_dim * 4 -> input_dim``) merges every 2x2 token window before the
    layer stack runs.

    Raises:
        ValueError: if ``config.projector_type`` is not one of the values above.
    """

    def __init__(
        self,
        config: "DeepseekVL2MlpProjectorConfig",
        quant_config: "Optional[QuantizationConfig]" = None,
    ):
        super().__init__()

        self.config = config

        if config.projector_type == "identity":
            # Keep the pass-through inside ``self.layers`` so ``forward`` can
            # iterate the stack uniformly for every projector type.
            self.layers = nn.ModuleList([nn.Identity()])

        elif config.projector_type == "linear":
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )

        elif config.projector_type == "mlp_gelu":
            mlp_depth = config.depth
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                ]
            )
            # Each remaining depth level adds GELU followed by a square linear.
            for _ in range(1, mlp_depth):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        config.n_embed,
                        config.n_embed,
                        quant_config=quant_config,
                    )
                )

        elif config.projector_type == "downsample_mlp_gelu":
            mlp_depth = config.depth
            mlp_ratio = config.mlp_ratio
            hidden_dim = config.n_embed * mlp_ratio
            # The first layer consumes the channel-wise concatenation of one
            # downsample_ratio x downsample_ratio patch (built in ``forward``).
            self.layers = nn.ModuleList(
                [
                    ReplicatedLinear(
                        config.input_dim
                        * config.downsample_ratio
                        * config.downsample_ratio,
                        hidden_dim,
                        quant_config=quant_config,
                    )
                ]
            )
            for _ in range(1, mlp_depth - 1):
                self.layers.append(nn.GELU())
                self.layers.append(
                    ReplicatedLinear(
                        hidden_dim,
                        hidden_dim,
                        quant_config=quant_config,
                    )
                )
            # Final GELU + projection back down to the embedding width.
            self.layers.append(nn.GELU())
            self.layers.append(
                ReplicatedLinear(
                    hidden_dim,
                    config.n_embed,
                    quant_config=quant_config,
                )
            )

        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        if config.token_pooling:
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
        """Run the optional pooling / downsampling step and the layer stack.

        Args:
            x: rank-3 tensor unpacked as ``(batch, tokens, channels)``. For the
                pooling and downsampling paths the token axis is reshaped to a
                square grid of side ``int(tokens ** 0.5)``.

        Returns:
            The output of the last layer in ``self.layers``. When a layer
            returns a tuple, only its first element is carried forward.
        """
        if self.config.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Cut the grid into non-overlapping 2x2 windows.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            # Move the window axis ahead of channels, then flatten each window
            # into a single vector of ``channels * 4`` features.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            # The pooling layer's result is indexed; element 0 is the output.
            x = self.token_pooling_layer(patches)[0]

        elif self.config.projector_type == "downsample_mlp_gelu":
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # compute padding so the grid side is a multiple of the ratio
            if h % self.config.downsample_ratio:
                pad = self.config.downsample_ratio - h % self.config.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Pad only the two spatial axes (at their ends), not channels.
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # 4 to 1 concat
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=self.config.downsample_ratio,
                stride=self.config.downsample_ratio,
                padding=0,
            )  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        for layer in self.layers:
            x = layer(x)
            # Some layers return a tuple; keep only the first element.
            if isinstance(x, tuple):
                x = x[0]
        return x
|
151
|
+
|
152
|
+
|
153
|
+
# todo
|
154
|
+
class DeepseekVL2ForCausalLM(nn.Module):
    """Vision encoder + MLP projector + ``DeepseekV2ForCausalLM`` language model.

    Image features produced by ``get_image_feature`` are scattered into the
    token embeddings (at positions selected by each input's
    ``images_emb_mask``) before the language model runs.
    """

    def __init__(
        self,
        config: "DeepseekVL2Config",
        quant_config: "Optional[QuantizationConfig]" = None,
    ):
        """Build the three sub-modules and the learned separator embeddings.

        Raises:
            ValueError: if ``config.tile_tag`` is not ``"2D"``.
            ImportError: if ``timm`` is not installed.
        """
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Scale for the random initialisation of the separator embeddings.
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag != "2D":
            raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")
        # Appended after every row of image features (see get_image_feature).
        self.image_newline = nn.Parameter(
            torch.randn(projector_config.n_embed) * embed_std
        )
        # NOTE: the attribute name (spelling included) is a registered
        # parameter name that load_weights looks up; do not rename it.
        self.view_seperator = nn.Parameter(
            torch.randn(projector_config.n_embed) * embed_std
        )

        # ----------- language model ------------
        language_config = config.language_config
        self.language_model = DeepseekV2ForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: "Optional[QuantizationConfig]"
    ) -> nn.Module:
        """Create the timm vision backbone in the current default dtype.

        ``vision_config`` and ``quant_config`` are accepted but not used here.

        Raises:
            ImportError: if ``timm`` cannot be imported.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception so the real cause stays visible.
            raise ImportError("Please install timm") from err

        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: "ForwardBatch",
        **kwargs: object,
    ):
        """Embed tokens, splice in image features, and run the language model.

        Image features are only merged when the batch is in extend mode and
        reports image inputs. Extra ``kwargs`` are accepted and ignored.

        Returns:
            Whatever ``self.language_model.forward`` returns.
        """
        input_embeds = self.language_model.model.embed_tokens(input_ids)
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.contains_image_inputs()
        ):
            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
            for idx, image in enumerate(forward_batch.mm_inputs):
                if image is None:
                    continue
                # Slice of the flattened batch that belongs to request ``idx``.
                start_idx = extend_start_loc_cpu[idx]
                end_idx = start_idx + extend_seq_lens_cpu[idx]
                # Follow the embeddings' device instead of assuming "cuda".
                images_emb_mask = image.images_emb_mask.to(
                    device=input_embeds.device
                )
                image_features = self.get_image_feature(image)
                input_embeds[start_idx:end_idx] = input_embeds[
                    start_idx:end_idx
                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

        outputs = self.language_model.forward(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this model.

        Names containing ``"language"`` have ``"language."`` stripped and are
        delegated to ``self.language_model.load_weights``. All other names must
        match a parameter of this module exactly (a missing name raises
        ``KeyError``); they are loaded with the parameter's own
        ``weight_loader`` when it has one, else ``default_weight_loader``.
        """
        params_dict = dict(self.named_parameters())
        # Consume the iterable lazily; there is no need to hold every tensor.
        for name, loaded_weight in weights:
            if "language" in name:
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], image_inputs: "MultimodalInputs"):
        """Return ``input_ids`` unchanged; no padding is applied."""
        return input_ids

    def get_image_feature(self, image_input: "MultimodalInputs"):
        """Encode all tiles of one input and lay them out as a token sequence.

        For every image, tile ``0`` of its group is treated as the global view
        and the following ``num_width_tiles * num_height_tiles`` tiles as local
        views. ``image_newline`` is appended to each feature row and
        ``view_seperator`` is inserted between the global and local parts,
        ordered according to ``self.global_view_pos``.

        Assumes ``image_input.image_spatial_crop`` is indexed as
        ``[0, image_idx] -> (num_width_tiles, num_height_tiles)`` - TODO confirm
        against the processor that builds it.

        Returns:
            A 2-D tensor of image embeddings for all images, concatenated
            along dim 0.
        """
        # Match the vision encoder's dtype and device.
        vision_param = next(self.vision.parameters())
        pixel_values = image_input.pixel_values.type(vision_param.dtype).to(
            device=vision_param.device
        )
        image_feature = self.vision.forward_features(pixel_values)
        images_embeds = self.projector(image_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw**0.5)
        tile_index = 0
        images_in_this_batch = []
        images_spatial_crop = image_input.image_spatial_crop
        for jdx in range(images_spatial_crop.shape[1]):
            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
            # A zero entry ends the list of images; later entries are skipped.
            if num_width_tiles == 0 or num_height_tiles == 0:
                break
            num_tiles_in_image = num_width_tiles * num_height_tiles

            # [hw, D]
            global_features = images_embeds[tile_index]

            # [num_height_tiles * num_width_tiles, hw, D]
            local_features = images_embeds[
                tile_index + 1 : tile_index + 1 + num_tiles_in_image
            ]
            tile_index += num_tiles_in_image + 1

            # format global and local features
            # ----------------- global view add newline -----------------
            # [hw, D] -> [h, w, D]
            global_features = global_features.view(h, w, n_dim)

            # [D] -> [h, 1, D]
            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
            global_features = torch.cat([global_features, new_lines_in_global], dim=1)

            # [h, w + 1, D] -> [h * (w + 1), D]
            global_features = global_features.view(-1, n_dim)

            # ----------------- local view add newline -----------------
            # [num_height_tiles * num_width_tiles, h * w, D] ->
            # [num_height_tiles * h, num_width_tiles * w, D]
            local_features = rearrange(
                local_features,
                "(th tw) (h w) d -> (th h) (tw w) d",
                th=num_height_tiles,
                tw=num_width_tiles,
                h=h,
                w=w,
            )

            # [D] -> [num_height_tiles * h, 1, D]
            new_lines_in_local = repeat(
                self.image_newline,
                "d -> (th h) 1 d",
                th=num_height_tiles,
                h=h,
            )

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            local_features = torch.cat([local_features, new_lines_in_local], dim=1)

            # [num_height_tiles * h, num_width_tiles * w + 1, D]
            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
            local_features = local_features.view(-1, n_dim)

            # merge global and local tiles
            if self.global_view_pos == "head":
                global_local_features = torch.cat(
                    [
                        global_features,
                        self.view_seperator[None, :],
                        local_features,
                    ]
                )
            else:
                global_local_features = torch.cat(
                    [
                        local_features,
                        self.view_seperator[None, :],
                        global_features,
                    ]
                )

            images_in_this_batch.append(global_local_features)

        return torch.cat(images_in_this_batch, dim=0)
|
356
|
+
|
357
|
+
|
358
|
+
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably the name the model loader looks up - confirm.
EntryClass = DeepseekVL2ForCausalLM
|