optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py
+++ b/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py
@@ -12,48 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from transformers import (
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import ColQwen2Config, ColQwen2ForRetrieval
 from transformers.modeling_utils import no_init_weights
 from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrievalOutput
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLModel,
-    Qwen2_5_VLRotaryEmbedding,
-)
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    Qwen2VLModel,
-    Qwen2VLRotaryEmbedding,
-)
-
-from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyModel,
-)
 
-from
-
-)
+from ....modeling import RBLNModel
+from ....transformers.modeling_outputs import _validate_output_hidden_states
 
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import PreTrainedModel
 
-    from .colqwen2_architecture import ColQwen2LanguageModelWrapper
 
-
-class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
+class RBLNColQwen2ForRetrieval(RBLNModel):
     """
-
-    This model inherits from [`
+    RBLNColQwen2ForRetrieval is a model for document retrieval using vision-language models.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based `ColQwen2ForRetrieval` model on RBLN devices.
     It implements the methods to convert a pre-trained transformers `ColQwen2ForRetrieval` model into a RBLN transformer model by:
@@ -61,326 +39,82 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
 
-    **Configuration:**
-    This model uses [`RBLNColQwen2ForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
-    the `rbln_config` parameter should be an instance of [`RBLNColQwen2ForRetrievalConfig`] or a dictionary conforming to its structure.
-
-    See the [`RBLNColQwen2ForRetrievalConfig`] class for all available configuration options.
-
     Examples:
         ```python
-
+        import torch
+        from PIL import Image
+        from transformers import ColQwen2Processor
+
+        from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
 
-        # Using a config dictionary
         rbln_config = {
-            "
-            "
+            "vlm": {
+                "visual": {
+                    "max_seq_lens": 6400,
+                },
+                "tensor_parallel_size": 4,
+                "kvcache_partition_len": 16384,
+                "max_seq_len": 16384 * 7,
             },
-            "max_seq_len": 32_768,
-            "tensor_parallel_size": 4,
-            "device": [0, 1, 2, 3],
-            "output_hidden_states": False,
         }
-        model = RBLNColQwen2ForRetrieval.from_pretrained(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-        )
+        model = RBLNColQwen2ForRetrieval.from_pretrained("vidore/colqwen2-v1.0-hf", rbln_config=config)
+        model.save_pretrained("compiled-colqwen2-v1.0-hf")
+
+        # The document page screenshots from your corpus. Below are dummy images.
+        images = [
+            Image.new("RGB", (128, 128), color="white"),
+            Image.new("RGB", (64, 32), color="black"),
+        ]
+        processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")
+
+        queries = [
+            "When was the United States Declaration of Independence proclaimed?",
+            "Who printed the edition of Romeo and Juliet?",
+        ]
+        inputs_images = processor(images=images)
+        inputs_text = processor(text=queries)
+
+        # Forward pass
+        with torch.no_grad():
+            image_embeddings = model(**inputs_images).embeddings
+            query_embeddings = model(**inputs_text).embeddings
+
+        scores = processor.score_retrieval(query_embeddings, image_embeddings)
+        print("Retrieval scores (query x image):")
+        print(scores)
         ```
     """
 
-
-    auto_model_class = None
+    _rbln_submodule_postfix = "model"
    _rbln_submodules = [
-        {"name": "
+        {"name": "vlm"},
    ]
-
-    _use_rotary_emb = False
+    _supports_non_fp32 = True
 
     def __post_init__(self, **kwargs):
-        self.
-
-        artifacts = torch.load(
-            self.model_save_dir / self.subfolder / "torch_artifacts.pth",
-            weights_only=False,
-        )
-        self.embed_tokens = self._create_embedding_layer()
-        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        self.visual = self.rbln_submodules[0]
-        self.prefill_runtime = self.model[0]
-        self.mrope_section = self.config.text_config.rope_scaling["mrope_section"]
-        self.is_colqwen2_5 = "qwen2_5_vl" in self.config.model_type
-
-        if self.is_colqwen2_5:
-            self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config.text_config)
-        else:
-            self.rotary_emb = Qwen2VLRotaryEmbedding(self.config.text_config)
-        self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+        self.vlm_model = self.rbln_submodules[0]
+        return super().__post_init__(**kwargs)
 
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-        if
-
-
-
-
-        return model.to(torch.float32)
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.text_config.vocab_size,
-                self.config.text_config.hidden_size,
-                self.config.text_config.pad_token_id,
-            )
-        return embed_tokens
-
-    @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        rbln_config: RBLNColQwen2ForRetrievalConfig,
-        model_config: PretrainedConfig,
-    ):
-        text_config = model_config.text_config
-        input_info = super().get_input_info(
-            batch_size,
-            query_length,
-            rbln_config=rbln_config,
-            model_config=text_config,
-        )
-
-        pos_idx = 3
-        input_info.insert(
-            pos_idx,
-            (
-                "position_emb",
-                [
-                    2,
-                    batch_size,
-                    1,
-                    query_length,
-                    text_config.hidden_size // text_config.num_attention_heads,
-                ],
-                rbln_config.torch_dtype,
-            ),
-        )
-
-        # remove query postion from input_info
-        if "query_position" in input_info:
-            query_position = input_info.pop(4)
-            assert query_position[0] == "query_position", print(query_position[0], "is deleted.")
-        return input_info
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNColQwen2ForRetrievalConfig] = None,
-    ) -> RBLNColQwen2ForRetrievalConfig:
-        model_config = model_config.vlm_config if hasattr(model_config, "vlm_config") else model_config
-        if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = getattr(model_config.text_config, "output_hidden_states", False)
-
-        return super()._update_rbln_config(
-            preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
-        )
-
-    def _get_position_embeddings(self, hidden_states, position_ids):
-        cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        return torch.stack([cos, sin])
-
-    def get_rope_index(self, *args, **kwargs):
-        if self.is_colqwen2_5:
-            return Qwen2_5_VLModel.get_rope_index(self, *args, **kwargs)
-        else:
-            return Qwen2VLModel.get_rope_index(self, *args, **kwargs)
-
-    def _preprocess_visual(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        pixel_values: torch.Tensor = None,
-        pixel_values_videos: torch.FloatTensor = None,
-        image_grid_thw: torch.LongTensor = None,
-        video_grid_thw: torch.LongTensor = None,
-        second_per_grid_ts: torch.Tensor = None,
-    ):
-        batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
-
-        if pixel_values is not None:
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+        # if model is from Colpali-engine, convert it to a ColQwen2ForRetrieval model
+        if hasattr(model, "custom_text_proj"):
+            with no_init_weights():
+                model_config = ColQwen2Config(
+                    vlm_config=model.config, embedding_dim=model.custom_text_proj.out_features
                )
+                new_model = ColQwen2ForRetrieval._from_config(model_config)
+                new_model.embedding_proj_layer = model.custom_text_proj
+                new_model.vlm.model.visual.load_state_dict(model.visual.state_dict())
+                new_model.vlm.model.language_model.load_state_dict(model.language_model.state_dict())
+                model = new_model
 
-
-
-
+            # replace the lm_head with the custom text projection layer for optimization
+            model.vlm.model.lm_head = model.embedding_proj_layer
+            model.vlm.model.config.embedding_dim = model.config.embedding_dim
 
-
-
-
-        if pixel_values_videos is not None:
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
-
-        max_inputs_len = input_ids.shape[1]
-        head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
-        all_rope_deltas = []
-
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        image_idx, video_idx = 0, 0
-
-        for b_idx in range(batch_size):
-            input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
-            vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
-            vision_tokens = input_id[0][vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            args = [
-                input_id,
-                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
-                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
-            ]
-            if self.config.model_type == "qwen2_5_vl":
-                args.append(
-                    second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None
-                )
-            position_ids, rope_deltas = self.get_rope_index(*args)
-            image_idx += image_nums
-            video_idx += video_nums
-
-            position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
-            mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
-            all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
-            all_rope_deltas.append(rope_deltas)
-
-        rope_deltas = torch.stack(all_rope_deltas)
-
-        return inputs_embeds, all_position_embeds, rope_deltas
-
-    def _preprocess_chunked_prefill(self, inputs_embeds, attention_mask, position_embed):
-        # valid sequence length of inputs_embeds
-        query_length = inputs_embeds.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
-        # extract valid inputs
-        inputs_embeds = inputs_embeds[:, attention_mask.bool()] if attention_mask is not None else inputs_embeds
-        position_embed = (
-            position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-        )
-
-        # add padding for chunked prefill
-        padding_size = (
-            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
-        ) % self.rbln_config.prefill_chunk_size
-        padded_len = query_length + padding_size
-
-        inputs_embeds = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, padding_size))
-        position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
-        return inputs_embeds, position_embed, cache_position, query_length
-
-    def _chunked_prefill_forward(
-        self,
-        inputs_embeds: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = False,
-    ):
-        padded_inputs_embeds, padded_position_embed, cache_position, query_length = self._preprocess_chunked_prefill(
-            inputs_embeds, attention_mask, position_embed
-        )
-
-        # Chunked prefill
-        projs = []
-        all_hidden_states = [] if output_hidden_states else None
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = padded_inputs_embeds[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_embed_chunk = padded_position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            # Forward pass for the current chunk
-            proj = self.prefill_runtime(
-                inputs_embeds=input_chunk,
-                cache_position=cache_pos_chunk,
-                block_tables=self.block_tables,
-                position_emb=position_embed_chunk,
-            )
-
-            if output_hidden_states:
-                projs.append(proj[0])
-                all_hidden_states.append(proj[1:])
-            else:
-                projs.append(proj)
-
-        projs = torch.concat(projs, dim=-2)[:, :query_length]
-        if output_hidden_states:
-            # Concatenate chunks for each layer
-            concatenated_hidden_states = [
-                torch.concat(hs_chunks, dim=-2)[:, :query_length] for hs_chunks in list(zip(*all_hidden_states))
-            ]
-            all_hidden_states = tuple(concatenated_hidden_states)
-
-        return self._postprocess_chunked_prefill(projs, attention_mask), all_hidden_states
-
-    def _postprocess_chunked_prefill(self, projs, attention_mask):
-        # index copy for attention mask
-        if attention_mask is not None:
-            embedding = torch.full(
-                (1, attention_mask.shape[-1], projs.shape[-1]),
-                fill_value=1e-10,
-                dtype=projs.dtype,
-            )
-            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-            embedding.index_copy_(dim=-2, index=mask_indices, source=projs)
-        else:
-            embedding = projs
-        return embedding
+        # Some of the model weights are different from the model.dtype(vidore/colqwen2-v1.0-hf)
+        return model.to(model.dtype)
 
     def forward(
         self,
@@ -388,22 +122,34 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
-        pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ) ->
-
-
-
+    ) -> Union[Tuple, ColQwen2ForRetrievalOutput]:
+        """
+        Runs a ColQwen2 retrieval forward pass on text tokens and optional visual inputs.
+
+        Args:
+            input_ids (torch.LongTensor, optional): Indices of the textual tokens. Mutually exclusive with `inputs_embeds`.
+            inputs_embeds (torch.FloatTensor, optional): Pre-computed embeddings fed directly into the language model.
+            attention_mask (torch.Tensor, optional): Mask that selects which token positions contribute to the loss/embeddings.
+            pixel_values (torch.Tensor, optional): Flattened image patches produced by `ColQwen2Processor` for document pages.
+            image_grid_thw (torch.LongTensor, optional): Per-image `(t, h, w)` grid metadata that allows unpadding of `pixel_values`.
+            output_hidden_states (bool, optional): If `True`, expose intermediate decoder hidden states.
+            return_dict (bool, optional): If `True`, return a `ColQwen2ForRetrievalOutput`; otherwise return a tuple.
+            **kwargs (dict[str, Any], optional): Extra multimodal args forwarded to the wrapped VLM (e.g. `pixel_values_videos`,
+                `video_grid_thw`, `second_per_grid_ts`).
+
+        Returns:
+            Dataclass containing the embeddings and hidden states of the VLM model.
+        """
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
 
-        if
-
-
-
-        )
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=self.dtype)
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
         if pixel_values is not None and image_grid_thw is not None:
@@ -412,35 +158,25 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
                 [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
                 dim=0,
             )
-
-
-                input_ids,
-                attention_mask,
-                pixel_values,
-
-
-
-
+
+        vlm_output = self.vlm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
        )
-
-
-
-
-            proj = self._chunked_prefill_forward(
-                inputs_embeds[b_idx : b_idx + 1],
-                attention_mask[b_idx] if attention_mask is not None else None,
-                position_embed[:, b_idx : b_idx + 1],
-                output_hidden_states=output_hidden_states,
-            )
-            projs.append(proj[0])
-        all_hidden_states = proj[1] if output_hidden_states else ()
+        hidden_states = vlm_output.hidden_states if output_hidden_states else None
+
+        embeddings = vlm_output[0]
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
 
-
-
-        projs = projs / projs.norm(dim=-1, keepdim=True)
-        projs = projs * attention_mask.unsqueeze(-1)
+        if attention_mask is not None:
+            embeddings = embeddings * attention_mask.unsqueeze(-1)
 
         return ColQwen2ForRetrievalOutput(
-            embeddings=
-            hidden_states=
+            embeddings=embeddings,
+            hidden_states=hidden_states,
        )