optimum-rbln 0.8.1a4__py3-none-any.whl → 0.8.1a5__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
optimum/rbln/__init__.py CHANGED
@@ -70,6 +70,8 @@ _import_structure = {
         "RBLNCLIPVisionModelConfig",
         "RBLNCLIPVisionModelWithProjection",
         "RBLNCLIPVisionModelWithProjectionConfig",
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
         "RBLNDistilBertForQuestionAnswering",
@@ -297,6 +299,8 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelConfig,
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
         RBLNDistilBertForQuestionAnswering,

optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.8.1a4'
-__version_tuple__ = version_tuple = (0, 8, 1, 'a4')
+__version__ = version = '0.8.1a5'
+__version_tuple__ = version_tuple = (0, 8, 1, 'a5')

optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
-from diffusers.models.controlnet import ControlNetOutput
+from diffusers.models.controlnets.controlnet import ControlNetOutput
 from transformers import PretrainedConfig
 
 from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
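
Note: the `ControlNetOutput` import path changed because diffusers moved its controlnet modules into a `controlnets` subpackage; with `diffusers==0.34.0` pinned in this release's metadata, the new path is the right one. For code that has to work across older and newer diffusers releases, a version-tolerant import along these lines is a common workaround (a sketch, not part of this package):

```python
# Hypothetical compatibility shim, not part of optimum-rbln:
# prefer the post-reorganization path and fall back to the older one.
try:
    from diffusers.models.controlnets.controlnet import ControlNetOutput
except ImportError:  # older diffusers releases
    from diffusers.models.controlnet import ControlNetOutput
```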

optimum/rbln/transformers/__init__.py CHANGED
@@ -50,6 +50,8 @@ _import_structure = {
         "RBLNBlip2QFormerModelConfig",
         "RBLNBlip2VisionModel",
         "RBLNBlip2VisionModelConfig",
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",

optimum/rbln/transformers/models/__init__.py CHANGED
@@ -69,6 +69,10 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjection",
         "RBLNCLIPVisionModelWithProjectionConfig",
     ],
+    "colpali": [
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
+    ],
     "distilbert": [
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
@@ -193,6 +197,10 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
     )
+    from .colpali import (
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
+    )
     from .decoderonly import (
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,

optimum/rbln/transformers/models/colpali/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .configuration_colpali import RBLNColPaliForRetrievalConfig
+from .modeling_colpali import RBLNColPaliForRetrieval

optimum/rbln/transformers/models/colpali/colpali_architecture.py ADDED
@@ -0,0 +1,221 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import GemmaForCausalLM, GemmaModel
+
+from ..decoderonly.decoderonly_architecture import (
+    RotaryEmbedding,
+    apply_rotary_pos_emb,
+)
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    cos = cos[position_ids[0]][None, None, None, :, :]
+    sin = sin[position_ids[0]][None, None, None, :, :]
+
+    return cos, sin
+
+
+class RBLNColPaliForRetrievalWrapper(nn.Module):
+    def __init__(
+        self,
+        causal_lm: GemmaForCausalLM,
+        embedding_proj_layer: nn.Module,
+        max_seq_len: int,
+        output_hidden_states: bool = False,
+    ):
+        super().__init__()
+        self.text_config = causal_lm.config
+        self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+
+        self.output_hidden_states = output_hidden_states
+        self.language_model = self.convert_to_rbln_language_model(causal_lm.model, max_seq_len)
+
+        self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
+        self.embedding_proj_layer = embedding_proj_layer
+
+    def get_rotary_emb(self, max_seq_len):
+        return RotaryEmbedding(config=self.text_config, max_seq_len_cached=max_seq_len)
+
+    def convert_to_rbln_language_model(self, gemma_model: GemmaModel, max_seq_len: int):
+        new_layers = []
+        for layer in gemma_model.layers:
+            new_self_attn = ColPaliAttention(
+                layer.self_attn,
+            )
+            new_layer = ColPaliLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = ColPaliModel(
+            gemma_model,
+            new_layers,
+            output_hidden_states=self.output_hidden_states,
+            max_seq_len=max_seq_len,
+        )
+
+        return new_model
+
+    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor):
+        attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min
+        attention_mask = attention_mask[:, None, None, None, :]
+
+        hidden_states, all_hidden_states = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            rotary_emb=self.rotary_emb,
+            position_ids=position_ids,
+        )
+        embeddings = self.embedding_proj_layer(hidden_states)
+
+        if self.output_hidden_states:
+            return embeddings, all_hidden_states
+        else:
+            return embeddings
+
+
+class ColPaliModel(nn.Module):
+    def __init__(
+        self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
+    ):
+        super().__init__()
+        self._original_mod = model
+        self.layers = nn.ModuleList(layers)
+        self.output_hidden_states = output_hidden_states
+        self.norm = self._original_mod.norm
+        self.hidden_size = self._original_mod.config.hidden_size
+        self.max_seq_len = max_seq_len
+
+    def forward(
+        self,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        hidden_states = inputs_embeds * self.hidden_size**0.5
+
+        cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        all_hidden_states = () if self.output_hidden_states else None
+        for layer in self.layers:
+            if self.output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                cos=cos,
+                sin=sin,
+            )
+        hidden_states = self.norm(hidden_states)
+
+        if self.output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
+
+
+class ColPaliLayer(nn.Module):
+    def __init__(self, layer, self_attn: "ColPaliAttention"):
+        super().__init__()
+        self._original_mod = layer
+        self.self_attn = self_attn
+        self.mlp = layer.mlp
+        self.input_layernorm = layer.input_layernorm
+        self.post_attention_layernorm = layer.post_attention_layernorm
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            cos=cos,
+            sin=sin,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class ColPaliAttention(nn.Module):
+    def __init__(self, self_attn):
+        super().__init__()
+        self._original_mod = self_attn
+        self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
+            self._original_mod.config, "num_attention_heads"
+        )
+        self.head_dim = self._original_mod.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        if hasattr(self._original_mod, "num_key_value_heads"):
+            self.num_key_value_heads = self._original_mod.num_key_value_heads
+        elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
+            self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+        else:
+            self.num_key_value_heads = self.num_heads
+
+        self.__post_init__()
+
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.o_proj
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        return query_states, key_states, value_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+    ):
+        batch_size, query_length, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+        query_states = query_states.view(batch_size, query_length, 1, self.num_heads, self.head_dim).transpose(1, 3)
+        key_states = key_states.view(batch_size, query_length, 1, self.num_key_value_heads, self.head_dim).transpose(
+            1, 3
+        )
+        value_states = value_states.view(
+            batch_size, query_length, 1, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 3)
+
+        if cos is not None and sin is not None:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) * self.scaling
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 3)
+
+        attn_output = attn_output.reshape(batch_size, query_length, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
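
A note on the wrapper above: `RBLNColPaliForRetrievalWrapper.forward` converts the binary `attention_mask` (1 for real tokens, 0 for padding) into an additive bias before it reaches the attention scores. A minimal standalone sketch of that transform, with a hypothetical 1x4 mask:

```python
import torch

# Binary padding mask: 1 = keep, 0 = masked (hypothetical example values).
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

# Same transform as in RBLNColPaliForRetrievalWrapper.forward: masked positions
# get a very large negative bias, kept positions get 0.
bias = (1.0 - attention_mask) * torch.finfo(torch.float32).min

# Broadcast shape matching the 5-D attention layout used by ColPaliAttention,
# i.e. (batch, 1, 1, 1, seq_len), added to the scores before the softmax.
bias = bias[:, None, None, None, :]
print(bias.shape)  # torch.Size([1, 1, 1, 1, 4])
```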

optimum/rbln/transformers/models/colpali/configuration_colpali.py ADDED
@@ -0,0 +1,68 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Union
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN ColPali models for document retrieval.
+
+    This class extends RBLNModelConfig with specific configurations for ColPali models,
+    including vision tower settings and multi-sequence length support.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+
+    # Create a configuration object
+    config = RBLNColPaliForRetrievalConfig(
+        max_seq_lens=1152,
+        output_hidden_states=False,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNColPaliForRetrieval.from_pretrained(
+        "vidore/colpali-v1.3-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    submodules = ["vision_tower"]
+
+    def __init__(
+        self,
+        max_seq_lens: Union[int, List[int]] = None,
+        output_hidden_states: Optional[bool] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
+                This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.vision_tower = vision_tower
+        self.max_seq_lens = max_seq_lens
+        self.output_hidden_states = output_hidden_states
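
Because `max_seq_lens` may be a list, the language model is compiled once per length and the runtime then picks the smallest compiled bucket that fits the input (see `LoopLanguageModel.prepare_inputs` in the modeling file below). A minimal sketch of that bucket selection, using hypothetical lengths:

```python
import bisect

# Hypothetical compiled buckets, mirroring a multi-value max_seq_lens setting.
max_seq_lens = sorted({512, 1152, 2048})


def pick_bucket(input_len: int) -> int:
    # Smallest compiled max_seq_len that can hold the input sequence.
    idx = bisect.bisect_left(max_seq_lens, input_len)
    if idx == len(max_seq_lens):
        raise ValueError(f"seq_len {input_len} exceeds compiled buckets {max_seq_lens}")
    return max_seq_lens[idx]


print(pick_bucket(700))   # 1152
print(pick_bucket(1152))  # 1152
```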

optimum/rbln/transformers/models/colpali/modeling_colpali.py ADDED
@@ -0,0 +1,383 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import no_init_weights
+from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
+from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+from .colpali_architecture import RBLNColPaliForRetrievalWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class LoopVisionTower:
+    def __init__(self, vision_tower: RBLNModel) -> None:
+        self.vision_tower = vision_tower
+
+    def forward(self, pixel_values, **kwargs):
+        batch_size = pixel_values.shape[0]
+        outputs = []
+        for i in range(batch_size):
+            outputs.append(self.vision_tower(pixel_values[i : i + 1]))
+
+        last_hidden_states = [output.last_hidden_state for output in outputs]
+        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_states,
+        )
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.forward(*args, **kwds)
+
+    def __repr__(self) -> str:
+        return repr(self.vision_tower)
+
+
+class LoopLanguageModel:
+    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig) -> None:
+        self.language_model = language_model
+        self.rbln_config = rbln_config
+
+    def prepare_inputs(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
+        input_len = inputs_embeds.shape[1]
+        idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
+        if idx == len(self.rbln_config.max_seq_lens):
+            raise ValueError(
+                f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
+            )
+        else:
+            max_seq_len = self.rbln_config.max_seq_lens[idx]
+
+        inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+        attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+        position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+        return inputs_embed, attn_mask, position_ids
+
+    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
+        padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
+        input_batch_size = inputs_embeds.shape[0]
+        input_seq_len = inputs_embeds.shape[1]
+
+        all_embeddings = []
+        all_hidden_states = []
+        for i in range(input_batch_size):
+            outputs = self.language_model(
+                inputs_embeds=padded_inputs_embed[i : i + 1],
+                attention_mask=padded_attn_mask[i : i + 1],
+                position_ids=padded_position_ids,
+            )
+
+            if self.rbln_config.output_hidden_states:
+                embedding = outputs[0]
+                hidden_states = outputs[1:]
+            else:
+                embedding = outputs
+                hidden_states = None
+
+            all_embeddings.append(embedding)
+            all_hidden_states.append(hidden_states)
+
+        embeddings = torch.cat(all_embeddings, dim=0)[:, :input_seq_len]
+        if self.rbln_config.output_hidden_states:
+            hidden_states = [
+                torch.cat(
+                    [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
+                    dim=0,
+                )
+                for layer_idx in range(len(all_hidden_states[0]))
+            ]
+            return embeddings, tuple(hidden_states)
+        else:
+            return embeddings
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.forward(*args, **kwds)
+
+    def __repr__(self) -> str:
+        return repr(self.language_model)
+
+
+class RBLNColPaliForRetrieval(RBLNModel):
+    """
+    The ColPali Model transformer for document retrieval using vision-language models.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based ColPaliForRetrieval model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers ColPaliForRetrieval model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNColPaliForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNColPaliForRetrievalConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNColPaliForRetrievalConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNColPaliForRetrieval
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_lens` is automatically inferred from the model config
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_max_seq_lens=1152,
+        )
+
+        # Using a config dictionary
+        rbln_config = {
+            "max_seq_lens": 1152,
+            "output_hidden_states": False,
+        }
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=rbln_config
+        )

+        # Using a RBLNColPaliForRetrievalConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNColPaliForRetrievalConfig
+
+        config = RBLNColPaliForRetrievalConfig(
+            max_seq_lens=1152,
+            output_hidden_states=False,
+            tensor_parallel_size=4
+        )
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    auto_model_class = None
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+        self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.multi_modal_projector = self._create_multi_modal_projector()
+        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+        return super().__post_init__(**kwargs)
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
+
+    def _create_multi_modal_projector(self):
+        with no_init_weights():
+            multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
+        return multi_modal_projector
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+        return RBLNColPaliForRetrievalWrapper(
+            causal_lm=model.vlm.language_model,
+            embedding_proj_layer=model.embedding_proj_layer,
+            max_seq_len=max(rbln_config.max_seq_lens),
+            output_hidden_states=rbln_config.output_hidden_states,
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["embed_tokens"] = model.vlm.get_input_embeddings().state_dict()
+        save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        hidden_size = model_config.vlm_config.text_config.hidden_size
+        if rbln_config.max_seq_lens is None:
+            rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
+        if isinstance(rbln_config.max_seq_lens, int):
+            rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
+        rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
+
+        input_infos = []
+        for max_seq_len in rbln_config.max_seq_lens:
+            input_info = [
+                ("inputs_embeds", [1, max_seq_len, hidden_size], "float32"),
+                ("attention_mask", [1, max_seq_len], "float32"),
+                ("position_ids", [1, max_seq_len], "int32"),
+            ]
+            input_infos.append(input_info)
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+
+        return rbln_config
+
+    @classmethod
+    def from_model(cls, model: "PreTrainedModel", *args, **kwargs):
+        if not hasattr(model, "vision_tower"):
+            model.vision_tower = model.vlm.vision_tower
+            del model.vlm.vision_tower
+        model = super().from_model(model, *args, **kwargs)
+        return model
+
+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+        model.vision_tower = model.vlm.vision_tower
+        del model.vlm.vision_tower
+
+        return model
+
+    def get_image_features(self, pixel_values: torch.Tensor):
+        # Projects the last hidden state from the vision model into language model space.
+        # Args:
+        #     pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
+        #         The tensors corresponding to the input images.
+        # Returns:
+        #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+
+        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
+        return image_features
+
+    def _preprocess_inputs(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        # Replace image id with PAD if the image token is OOV, to avoid index-errors
+        if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
+            special_image_mask = input_ids == self.config.vlm_config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(llm_input_ids)
+
+        # Merge text and images
+        image_features = None
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values)
+            special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, image_features
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> ColPaliForRetrievalOutput:
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=self.dtype)
+
+        if output_attentions:
+            raise ValueError("output_attentions is not supported for RBLNColPaliForRetrieval")
+
+        if output_hidden_states is not None and output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        inputs_embeds, image_features = self._preprocess_inputs(
+            input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+        )
+
+        # Embedding_proj_layer is fused on the bottom of the language model.
+        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+
+        embeddings = outputs if not self.rbln_config.output_hidden_states else outputs[0]
+        hidden_states = None if not self.rbln_config.output_hidden_states else outputs[1]
+
+        # L2 normalization
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+
+        if attention_mask is not None:
+            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if not return_dict:
+            return (embeddings, hidden_states, image_features)
+        else:
+            return ColPaliForRetrievalOutput(
+                embeddings=embeddings,
+                hidden_states=hidden_states,
+                image_hidden_states=image_features,
+            )
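
For context, a compiled `RBLNColPaliForRetrieval` is intended as a drop-in for the Hugging Face ColPali embedding flow. The sketch below assumes the standard `transformers` `ColPaliProcessor` API (including `score_retrieval`) and a checkpoint that was already compiled and saved with `export=True`; the save path is hypothetical, so treat this as an illustration rather than documented usage:

```python
import torch
from PIL import Image
from transformers import ColPaliProcessor

from optimum.rbln import RBLNColPaliForRetrieval

# Load a model previously compiled with export=True (hypothetical local path).
model = RBLNColPaliForRetrieval.from_pretrained("path/to/compiled-colpali")
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")

images = [Image.new("RGB", (448, 448), color="white")]   # placeholder page image
queries = ["What is the total revenue in 2024?"]          # hypothetical query

image_inputs = processor(images=images, return_tensors="pt")
query_inputs = processor(text=queries, return_tensors="pt")

with torch.no_grad():
    image_embeddings = model(**image_inputs).embeddings
    query_embeddings = model(**query_inputs).embeddings

# Late-interaction (MaxSim) scoring between queries and pages.
scores = processor.score_retrieval(query_embeddings, image_embeddings)
print(scores.shape)  # (num_queries, num_images)
```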

optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py CHANGED
@@ -79,7 +79,7 @@ class Qwen2_5_VLVisionFullAttention(nn.Module):
         super().__init__()
         self._origin_model = model
         self.num_heads = model.num_heads
-        self.head_dim = model.head_dim
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
 
@@ -114,7 +114,7 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
         super().__init__()
         self._origin_model = model
         self.num_heads = model.num_heads
-        self.head_dim = model.head_dim
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
         self.window_seq_len = window_seq_len

optimum_rbln-0.8.1a4.dist-info/METADATA → optimum_rbln-0.8.1a5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.1a4
+Version: 0.8.1a5
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
@@ -25,9 +25,9 @@ Requires-Python: <3.13,>=3.9
 Requires-Dist: accelerate>=1.0.1
 Requires-Dist: diffusers==0.34.0
 Requires-Dist: packaging>=24.1
-Requires-Dist: torch==2.6.0
-Requires-Dist: torchaudio<=2.6.0
-Requires-Dist: torchvision<=0.21.0
+Requires-Dist: torch==2.7.0
+Requires-Dist: torchaudio<=2.7.0
+Requires-Dist: torchvision<=0.22.0
 Requires-Dist: transformers==4.51.3
 Description-Content-Type: text/markdown
 

optimum_rbln-0.8.1a4.dist-info/RECORD → optimum_rbln-0.8.1a5.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
-optimum/rbln/__init__.py,sha256=qJJTumXhoFnawXGpeGJbAm4J4A9FFwD1SQ2MqcKDXoM,14436
-optimum/rbln/__version__.py,sha256=hdBV0MOKkAsGp6FVqyauDmHCC6gC0y_cyykn1_s49sg,519
+optimum/rbln/__init__.py,sha256=Z5GM8hmc_cgNzhdfOAKbAQr-vFP24kC-IbiRaIOIxxE,14584
+optimum/rbln/__version__.py,sha256=Ln2yvKWXaraNsP7hCs26LOEd96BBrL7JNrmQ42n0dqA,519
 optimum/rbln/configuration_utils.py,sha256=o5oer7fBdE-MHLGNXoP35FjmuQbMmjEIDv0QE_k3kpo,32336
 optimum/rbln/modeling.py,sha256=ZlJ_tOCWiFjDIlwJ_B_HOCO0kBduWrBAbW9VSEVIAFg,12088
 optimum/rbln/modeling_base.py,sha256=5fUb1FaxfjApzJIkT8-SrPhuygGo_1Uc0i7UedawOeE,23393
@@ -20,7 +20,7 @@ optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.p
 optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py,sha256=5pDsxol2tm9hYs8u6_6713VwHxCo-iNhAK5G4JVwNwU,7952
 optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py,sha256=zg7aRyp8jYJuAnb_dTg6HdACCcAvhv1jX2FhEfRD6V0,7114
 optimum/rbln/diffusers/models/__init__.py,sha256=mkCvJyH1KcwrsUvYSq_bVC79oOfyqtBSFDyPS1_48wA,1478
-optimum/rbln/diffusers/models/controlnet.py,sha256=yKPQTO2jwb9VRMagiqzEXMAwJfcyAnfqMD7Lc8AOsr8,10573
+optimum/rbln/diffusers/models/controlnet.py,sha256=6owledPe9BXhbZOG8lbuuYvpBU0UrQV7zmat6SoMXOM,10585
 optimum/rbln/diffusers/models/autoencoders/__init__.py,sha256=dg17ZTUsiqTcbIaEE4fqew9uRbao0diQ21PXvRKIqKg,679
 optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py,sha256=UwaYFHXKRJTzJDmfYMC7-xvaWTh7JUDQYD3bRiQs4ZA,8367
 optimum/rbln/diffusers/models/autoencoders/vae.py,sha256=ja9yLhPYGmg1d3Kec6fS-6XgfS0yVJXuVsNDD0X3yHM,4048
@@ -61,11 +61,11 @@ optimum/rbln/ops/flash_attn.py,sha256=z39DJZSk94630ueoOCkiybxR5gzvNR-SRADHs0F6pz
 optimum/rbln/ops/kv_cache_update.py,sha256=HjnHBR-oFrJQibsVnkYb0P5_-wEma8jl0mkjkylwakU,1270
 optimum/rbln/ops/linear.py,sha256=1_7Hg-9wXxhu97fqPobotLQx17k7VPeSSL91_9Z7EDg,1018
 optimum/rbln/ops/sliding_window_attn.py,sha256=EQrV_yRGc5z6kvwEsAcLP028bJWkQg2UPI3xubt9skU,3487
-optimum/rbln/transformers/__init__.py,sha256=fE-kzDnWj0ueAG-xDrIKdBX59wCE__8m86uBMBOEb9g,9031
+optimum/rbln/transformers/__init__.py,sha256=MF7OaGf-KI9rz4EOzejxHTDYUB3RO2L02BquTe0PXmI,9107
 optimum/rbln/transformers/configuration_generic.py,sha256=kNhPWtzF0IovUnrsXfxXdXITqgpfCAAedjfB6jSAhEg,5131
 optimum/rbln/transformers/modeling_generic.py,sha256=u1JzjWcPsQgH_rqBzRVr582NARqOk7XVKgY4CdEfXe8,12228
 optimum/rbln/transformers/modeling_rope_utils.py,sha256=6Zg3r-TeUk4WQAlr95pqfhuoAD_RQ4njT1rbO9uPL0Q,14379
-optimum/rbln/transformers/models/__init__.py,sha256=-rc_00p4d58cdM2ylmgURxoAGKgIRF7X7r6z1w6h3mo,10061
+optimum/rbln/transformers/models/__init__.py,sha256=VVQJgpUUnN4MPAQlOsxsw63w7WPK05ggFfRkGYuZFJQ,10266
 optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py,sha256=I2vL4lrzbT5p4eJcH-EKHzEfcPkj_XVsie7jb9q6yic,775
 optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py,sha256=z7LJiVJPmnlCM3mcyhPJP8AufSrxO_dsPeJ51onq-Nc,833
 optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py,sha256=FIKEVWpIt6-JQX9B_rAfCrAPqdUHtR2i8D_X2k7639E,1498
@@ -85,6 +85,10 @@ optimum/rbln/transformers/models/blip_2/modeling_blip_2.py,sha256=gx9pPXQfaIjDUN
 optimum/rbln/transformers/models/clip/__init__.py,sha256=TLeXDqcFK6M6v9x7Xr64kBbqGu3hFHM7p754dQ8UVQc,938
 optimum/rbln/transformers/models/clip/configuration_clip.py,sha256=mgtR_lS1_g5vAh_wWarff3-pwM_tzzRAWm7XkfhGwmo,3019
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=0u1JTlO47qoH_-qxWGvXLc67whddLzcuLoMB5KaMh94,7285
+optimum/rbln/transformers/models/colpali/__init__.py,sha256=n3rueXT_oC0N8myoZiic0YkVK24CW5hZBPa-0L8so6Y,119
+optimum/rbln/transformers/models/colpali/colpali_architecture.py,sha256=bWG7TehWRZkTh2y6mGkpd85_onWAyiyKdaQC9TFsy3E,8065
+optimum/rbln/transformers/models/colpali/configuration_colpali.py,sha256=yPzLYON6qRJlBkzxFfIBzBWd2KjYWvdClO4iAqd_V7E,2609
+optimum/rbln/transformers/models/colpali/modeling_colpali.py,sha256=jzvJCBrrCXSpjfmJ3O-VvPNFGWGaNbpOV09JwLPAZWs,15757
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=vQYZDDdoddwA7yKc5zzrq2Zs9sax-0p8rNF_aYfF4bk,1006
 optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py,sha256=cakn8RGo8gS3nmXdEqOfC2xUBOMGInROgLEbCOoLFR0,13398
 optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=YAn8J_lIq4IS-HM_gbi5Qov8_osxhWtBr5z_28QRbGM,49667
@@ -144,7 +148,7 @@ optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLoh
 optimum/rbln/transformers/models/qwen2_5_vl/__init__.py,sha256=rAW3DKQUzGL6EMwa5r1iLu94yhpiZpk6zfoD7TtYXrc,865
 optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py,sha256=U3ngIfkA58itqQZqTf-gbISMPoV7ipDttI7V2uwK_18,4155
 optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py,sha256=Q4U-avMkby-CunNXEERqvRZx9duC5i-6UmfF1376ciU,26336
-optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py,sha256=PAQz__9o_f5phlozhhXAB8JErBlS1jc4FYZkZkSYJuI,7312
+optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py,sha256=oU4MyNeDHzqD3dl1DgwrMev07yvMFhl_hXvV6tRdXCo,7422
 optimum/rbln/transformers/models/resnet/__init__.py,sha256=0QqtEQF1IMYgEmmfXMGarCDS8kJB5tzODfwTEzDVZRg,837
 optimum/rbln/transformers/models/resnet/configuration_resnet.py,sha256=KQd887jgNOl_Am3b407P2OvKtzkkeBS1cEhCfiN0tJg,769
 optimum/rbln/transformers/models/resnet/modeling_resnet.py,sha256=E8vg3Rw_KsHt6vaOg0ungZD7sXe0T4OMP0X8NFG1EXI,816
@@ -191,7 +195,7 @@ optimum/rbln/utils/model_utils.py,sha256=4k5879Kh75m3x_vS4-qOGfqsOiAvc2kdNFFfvsF
 optimum/rbln/utils/runtime_utils.py,sha256=LoKNK3AQNV_BSScstIZWjICkJf265MnUgy360BOocVI,5454
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=w5mgPgncI740gVKMu3S-69DGNdUSI0bTZxegQGcZ98Y,5011
-optimum_rbln-0.8.1a4.dist-info/METADATA,sha256=jo7yVVPhX8QJJK0WE1x2ReG_VbuNiyhAkAPj9Um90A8,5299
-optimum_rbln-0.8.1a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-optimum_rbln-0.8.1a4.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-optimum_rbln-0.8.1a4.dist-info/RECORD,,
+optimum_rbln-0.8.1a5.dist-info/METADATA,sha256=yetswBiXM1Cce75lQOgrUw3pNMuaxt6XoaclWnDlGIE,5299
+optimum_rbln-0.8.1a5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.8.1a5.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.8.1a5.dist-info/RECORD,,