optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
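The file list above adds several new model families (gemma2, gpt_oss, paligemma, qwen2_moe, qwen3_moe) alongside the ColQwen2 rework shown in the diff below. As a rough orientation only, a compile-and-save sketch for one of the new families might look as follows; the class name `RBLNGemma2ForCausalLM` is an assumption inferred from the new `gemma2` module paths and the library's `RBLN<Architecture>` naming convention, and is not spelled out anywhere in this diff.

```python
# Hypothetical sketch: RBLNGemma2ForCausalLM is an assumed export of the new
# optimum/rbln/transformers/models/gemma2 module; verify the symbol before use.
from optimum.rbln import RBLNGemma2ForCausalLM

model = RBLNGemma2ForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",  # any Gemma-2 checkpoint
    export=True,             # compile the PyTorch weights into an RBLN graph
    rbln_config={
        "tensor_parallel_size": 4,  # same config keys as in the ColQwen2 docstring below
        "max_seq_len": 8192,
    },
)
model.save_pretrained("compiled-gemma-2-2b-it")
```

The unified diff below shows the reworked `modeling_colqwen2.py` hunks and an architecture module picking up the new MoE custom ops.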
@@ -12,48 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Optional, Union
+
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch
-from transformers import (
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import ColQwen2Config, ColQwen2ForRetrieval
 from transformers.modeling_utils import no_init_weights
 from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrievalOutput
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLModel,
-    Qwen2_5_VLRotaryEmbedding,
-)
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    Qwen2VLModel,
-    Qwen2VLRotaryEmbedding,
-)
-
-from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyModel,
-)

-from .configuration_colqwen2 import (
-    RBLNColQwen2ForRetrievalConfig,
-)
+from ....modeling import RBLNModel
+from ....transformers.modeling_outputs import _validate_output_hidden_states


 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import PreTrainedModel

-    from .colqwen2_architecture import ColQwen2LanguageModelWrapper

-
-class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
+class RBLNColQwen2ForRetrieval(RBLNModel):
     """
-    The ColQwen Model transformer for document retrieval using vision-language models.
-    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    RBLNColQwen2ForRetrieval is a model for document retrieval using vision-language models.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

     A class to convert and run pre-trained transformers based `ColQwen2ForRetrieval` model on RBLN devices.
     It implements the methods to convert a pre-trained transformers `ColQwen2ForRetrieval` model into a RBLN transformer model by:
@@ -61,326 +39,82 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.

-    **Configuration:**
-    This model uses [`RBLNColQwen2ForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
-    the `rbln_config` parameter should be an instance of [`RBLNColQwen2ForRetrievalConfig`] or a dictionary conforming to its structure.
-
-    See the [`RBLNColQwen2ForRetrievalConfig`] class for all available configuration options.
-
     Examples:
     ```python
-    from optimum.rbln import RBLNColQwen2ForRetrieval
+    import torch
+    from PIL import Image
+    from transformers import ColQwen2Processor
+
+    from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig

-    # Using a config dictionary
     rbln_config = {
-        "visual": {
-            "max_seq_lens": 6400,
+        "vlm": {
+            "visual": {
+                "max_seq_lens": 6400,
+            },
+            "tensor_parallel_size": 4,
+            "kvcache_partition_len": 16384,
+            "max_seq_len": 16384 * 7,
         },
-        "max_seq_len": 32_768,
-        "tensor_parallel_size": 4,
-        "device": [0, 1, 2, 3],
-        "output_hidden_states": False,
     }
-    model = RBLNColQwen2ForRetrieval.from_pretrained(
-        "vidore/colqwen2-v1.0-hf",
-        export=True,
-        rbln_config=rbln_config
-    )
-
-    # Using a RBLNColQwen2ForRetrievalConfig instance (recommended for type checking)
-    from optimum.rbln import RBLNColQwen2ForRetrievalConfig
-
-    config = RBLNColQwen2ForRetrievalConfig(
-        visual={
-            "max_seq_lens": 6400,
-            "device": 0,
-        },
-        max_seq_len=32_768,
-        tensor_parallel_size=4,
-        device=[0, 1, 2, 3],
-        output_hidden_states=False,
-    )
-    model = RBLNColQwen2ForRetrieval.from_pretrained(
-        "vidore/colqwen2-v1.0-hf",
-        export=True,
-        rbln_config=config
-    )
+    model = RBLNColQwen2ForRetrieval.from_pretrained("vidore/colqwen2-v1.0-hf", rbln_config=config)
+    model.save_pretrained("compiled-colqwen2-v1.0-hf")
+
+    # The document page screenshots from your corpus. Below are dummy images.
+    images = [
+        Image.new("RGB", (128, 128), color="white"),
+        Image.new("RGB", (64, 32), color="black"),
+    ]
+    processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0-hf")
+
+    queries = [
+        "When was the United States Declaration of Independence proclaimed?",
+        "Who printed the edition of Romeo and Juliet?",
+    ]
+    inputs_images = processor(images=images)
+    inputs_text = processor(text=queries)
+
+    # Forward pass
+    with torch.no_grad():
+        image_embeddings = model(**inputs_images).embeddings
+        query_embeddings = model(**inputs_text).embeddings
+
+    scores = processor.score_retrieval(query_embeddings, image_embeddings)
+    print("Retrieval scores (query x image):")
+    print(scores)
     ```
     """

-    main_input_name = "inputs_embeds"
-    auto_model_class = None
+    _rbln_submodule_postfix = "model"
     _rbln_submodules = [
-        {"name": "visual"},
+        {"name": "vlm"},
     ]
-    _decoder_wrapper_cls = ColQwen2LanguageModelWrapper
-    _use_rotary_emb = False
+    _supports_non_fp32 = True

     def __post_init__(self, **kwargs):
-        self.config = self.config.vlm_config if hasattr(self.config, "vlm_config") else self.config
-
-        artifacts = torch.load(
-            self.model_save_dir / self.subfolder / "torch_artifacts.pth",
-            weights_only=False,
-        )
-        self.embed_tokens = self._create_embedding_layer()
-        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        self.visual = self.rbln_submodules[0]
-        self.prefill_runtime = self.model[0]
-        self.mrope_section = self.config.text_config.rope_scaling["mrope_section"]
-        self.is_colqwen2_5 = "qwen2_5_vl" in self.config.model_type
-
-        if self.is_colqwen2_5:
-            self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config.text_config)
-        else:
-            self.rotary_emb = Qwen2VLRotaryEmbedding(self.config.text_config)
-        self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+        self.vlm_model = self.rbln_submodules[0]
+        return super().__post_init__(**kwargs)

     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-        if hasattr(model, "vlm"):
-            model.visual = model.vlm.visual
-            model.language_model = model.vlm.language_model
-
-        # FIXME: temporary fix for ColQwen2ForRetrieval dtype issue
-        return model.to(torch.float32)
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.text_config.vocab_size,
-                self.config.text_config.hidden_size,
-                self.config.text_config.pad_token_id,
-            )
-        return embed_tokens
-
-    @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        rbln_config: RBLNColQwen2ForRetrievalConfig,
-        model_config: PretrainedConfig,
-    ):
-        text_config = model_config.text_config
-        input_info = super().get_input_info(
-            batch_size,
-            query_length,
-            rbln_config=rbln_config,
-            model_config=text_config,
-        )
-
-        pos_idx = 3
-        input_info.insert(
-            pos_idx,
-            (
-                "position_emb",
-                [
-                    2,
-                    batch_size,
-                    1,
-                    query_length,
-                    text_config.hidden_size // text_config.num_attention_heads,
-                ],
-                rbln_config.torch_dtype,
-            ),
-        )
-
-        # remove query postion from input_info
-        if "query_position" in input_info:
-            query_position = input_info.pop(4)
-            assert query_position[0] == "query_position", print(query_position[0], "is deleted.")
-        return input_info
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNColQwen2ForRetrievalConfig] = None,
-    ) -> RBLNColQwen2ForRetrievalConfig:
-        model_config = model_config.vlm_config if hasattr(model_config, "vlm_config") else model_config
-        if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = getattr(model_config.text_config, "output_hidden_states", False)
-
-        return super()._update_rbln_config(
-            preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
-        )
-
-    def _get_position_embeddings(self, hidden_states, position_ids):
-        cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        return torch.stack([cos, sin])
-
-    def get_rope_index(self, *args, **kwargs):
-        if self.is_colqwen2_5:
-            return Qwen2_5_VLModel.get_rope_index(self, *args, **kwargs)
-        else:
-            return Qwen2VLModel.get_rope_index(self, *args, **kwargs)
-
-    def _preprocess_visual(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        pixel_values: torch.Tensor = None,
-        pixel_values_videos: torch.FloatTensor = None,
-        image_grid_thw: torch.LongTensor = None,
-        video_grid_thw: torch.LongTensor = None,
-        second_per_grid_ts: torch.Tensor = None,
-    ):
-        batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
-
-        if pixel_values is not None:
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+        # if model is from Colpali-engine, convert it to a ColQwen2ForRetrieval model
+        if hasattr(model, "custom_text_proj"):
+            with no_init_weights():
+                model_config = ColQwen2Config(
+                    vlm_config=model.config, embedding_dim=model.custom_text_proj.out_features
                 )
+            new_model = ColQwen2ForRetrieval._from_config(model_config)
+            new_model.embedding_proj_layer = model.custom_text_proj
+            new_model.vlm.model.visual.load_state_dict(model.visual.state_dict())
+            new_model.vlm.model.language_model.load_state_dict(model.language_model.state_dict())
+            model = new_model

-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+        # replace the lm_head with the custom text projection layer for optimization
+        model.vlm.model.lm_head = model.embedding_proj_layer
+        model.vlm.model.config.embedding_dim = model.config.embedding_dim

-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
-
-        if pixel_values_videos is not None:
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
-
-        max_inputs_len = input_ids.shape[1]
-        head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
-        all_rope_deltas = []
-
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        image_idx, video_idx = 0, 0
-
-        for b_idx in range(batch_size):
-            input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
-            vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
-            vision_tokens = input_id[0][vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            args = [
-                input_id,
-                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
-                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
-            ]
-            if self.config.model_type == "qwen2_5_vl":
-                args.append(
-                    second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None
-                )
-            position_ids, rope_deltas = self.get_rope_index(*args)
-            image_idx += image_nums
-            video_idx += video_nums
-
-            position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
-            mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
-            all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
-            all_rope_deltas.append(rope_deltas)
-
-        rope_deltas = torch.stack(all_rope_deltas)
-
-        return inputs_embeds, all_position_embeds, rope_deltas
-
-    def _preprocess_chunked_prefill(self, inputs_embeds, attention_mask, position_embed):
-        # valid sequence length of inputs_embeds
-        query_length = inputs_embeds.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
-        # extract valid inputs
-        inputs_embeds = inputs_embeds[:, attention_mask.bool()] if attention_mask is not None else inputs_embeds
-        position_embed = (
-            position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-        )
-
-        # add padding for chunked prefill
-        padding_size = (
-            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
-        ) % self.rbln_config.prefill_chunk_size
-        padded_len = query_length + padding_size
-
-        inputs_embeds = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, padding_size))
-        position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
-        return inputs_embeds, position_embed, cache_position, query_length
-
-    def _chunked_prefill_forward(
-        self,
-        inputs_embeds: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = False,
-    ):
-        padded_inputs_embeds, padded_position_embed, cache_position, query_length = self._preprocess_chunked_prefill(
-            inputs_embeds, attention_mask, position_embed
-        )
-
-        # Chunked prefill
-        projs = []
-        all_hidden_states = [] if output_hidden_states else None
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = padded_inputs_embeds[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_embed_chunk = padded_position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            # Forward pass for the current chunk
-            proj = self.prefill_runtime(
-                inputs_embeds=input_chunk,
-                cache_position=cache_pos_chunk,
-                block_tables=self.block_tables,
-                position_emb=position_embed_chunk,
-            )
-
-            if output_hidden_states:
-                projs.append(proj[0])
-                all_hidden_states.append(proj[1:])
-            else:
-                projs.append(proj)
-
-        projs = torch.concat(projs, dim=-2)[:, :query_length]
-        if output_hidden_states:
-            # Concatenate chunks for each layer
-            concatenated_hidden_states = [
-                torch.concat(hs_chunks, dim=-2)[:, :query_length] for hs_chunks in list(zip(*all_hidden_states))
-            ]
-            all_hidden_states = tuple(concatenated_hidden_states)
-
-        return self._postprocess_chunked_prefill(projs, attention_mask), all_hidden_states
-
-    def _postprocess_chunked_prefill(self, projs, attention_mask):
-        # index copy for attention mask
-        if attention_mask is not None:
-            embedding = torch.full(
-                (1, attention_mask.shape[-1], projs.shape[-1]),
-                fill_value=1e-10,
-                dtype=projs.dtype,
-            )
-            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-            embedding.index_copy_(dim=-2, index=mask_indices, source=projs)
-        else:
-            embedding = projs
-        return embedding
+        # Some of the model weights are different from the model.dtype(vidore/colqwen2-v1.0-hf)
+        return model.to(model.dtype)

     def forward(
         self,
@@ -388,22 +122,34 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
-        pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> torch.Tensor:
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
+    ) -> Union[Tuple, ColQwen2ForRetrievalOutput]:
+        """
+        Runs a ColQwen2 retrieval forward pass on text tokens and optional visual inputs.
+
+        Args:
+            input_ids (torch.LongTensor, optional): Indices of the textual tokens. Mutually exclusive with `inputs_embeds`.
+            inputs_embeds (torch.FloatTensor, optional): Pre-computed embeddings fed directly into the language model.
+            attention_mask (torch.Tensor, optional): Mask that selects which token positions contribute to the loss/embeddings.
+            pixel_values (torch.Tensor, optional): Flattened image patches produced by `ColQwen2Processor` for document pages.
+            image_grid_thw (torch.LongTensor, optional): Per-image `(t, h, w)` grid metadata that allows unpadding of `pixel_values`.
+            output_hidden_states (bool, optional): If `True`, expose intermediate decoder hidden states.
+            return_dict (bool, optional): If `True`, return a `ColQwen2ForRetrievalOutput`; otherwise return a tuple.
+            **kwargs (dict[str, Any], optional): Extra multimodal args forwarded to the wrapped VLM (e.g. `pixel_values_videos`,
+                `video_grid_thw`, `second_per_grid_ts`).
+
+        Returns:
+            Dataclass containing the embeddings and hidden states of the VLM model.
+        """
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=self.dtype)
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
         if pixel_values is not None and image_grid_thw is not None:
@@ -412,35 +158,25 @@ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
                 [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
                 dim=0,
             )
-        # visual preprocessing
-        inputs_embeds, position_embed, _ = self._preprocess_visual(
-            input_ids,
-            attention_mask,
-            pixel_values,
-            pixel_values_videos,
-            image_grid_thw,
-            video_grid_thw,
-            second_per_grid_ts,
+
+        vlm_output = self.vlm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
         )
-        batch_size = inputs_embeds.shape[0]
-
-        projs = []
-        for b_idx in range(batch_size):
-            proj = self._chunked_prefill_forward(
-                inputs_embeds[b_idx : b_idx + 1],
-                attention_mask[b_idx] if attention_mask is not None else None,
-                position_embed[:, b_idx : b_idx + 1],
-                output_hidden_states=output_hidden_states,
-            )
-            projs.append(proj[0])
-        all_hidden_states = proj[1] if output_hidden_states else ()
+        hidden_states = vlm_output.hidden_states if output_hidden_states else None
+
+        embeddings = vlm_output[0]
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

-        # postprocess
-        projs = torch.cat(projs, dim=0)
-        projs = projs / projs.norm(dim=-1, keepdim=True)
-        projs = projs * attention_mask.unsqueeze(-1)
+        if attention_mask is not None:
+            embeddings = embeddings * attention_mask.unsqueeze(-1)

         return ColQwen2ForRetrievalOutput(
-            embeddings=projs,
-            hidden_states=all_hidden_states,
+            embeddings=embeddings,
+            hidden_states=hidden_states,
         )
@@ -13,6 +13,8 @@
 # limitations under the License.

 from ....ops import (
+    custom_moe_ff,
+    custom_moe_glu,
     paged_attn_decode,
     paged_attn_prefill,
     paged_causal_attn_decode,
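The final hunk wires two new custom ops, `custom_moe_ff` and `custom_moe_glu` (declared in the new `optimum/rbln/ops/moe.py`), into an existing attention-op import block, presumably so the new MoE ports (qwen2_moe, qwen3_moe, gpt_oss) can lower their expert layers to fused kernels. For orientation only, the sketch below is a plain PyTorch reference of a top-k routed, GLU-gated mixture-of-experts layer, i.e. the computation such a fused op would stand in for; the function name, tensor layouts, and routing details are generic assumptions, not the signature of the RBLN ops.

```python
# Reference-only sketch of a top-k routed GLU mixture-of-experts block.
# This is NOT the API of optimum.rbln.ops.custom_moe_glu; names and shapes are assumed.
import torch
import torch.nn.functional as F


def moe_glu_reference(
    hidden: torch.Tensor,         # [tokens, hidden_dim]
    gate_w: torch.Tensor,         # [num_experts, hidden_dim, ffn_dim]
    up_w: torch.Tensor,           # [num_experts, hidden_dim, ffn_dim]
    down_w: torch.Tensor,         # [num_experts, ffn_dim, hidden_dim]
    router_logits: torch.Tensor,  # [tokens, num_experts]
    top_k: int = 2,
) -> torch.Tensor:
    # Pick the top-k experts per token and renormalize their routing weights.
    weights, expert_idx = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden)
    for e in range(gate_w.shape[0]):
        token_pos, slot = torch.where(expert_idx == e)   # tokens routed to expert e
        if token_pos.numel() == 0:
            continue
        x = hidden[token_pos]                            # gather this expert's tokens
        glu = F.silu(x @ gate_w[e]) * (x @ up_w[e])      # gated (SwiGLU-style) feed-forward
        out[token_pos] += weights[token_pos, slot].unsqueeze(-1) * (glu @ down_w[e])
    return out


# Tiny smoke test with random tensors.
tokens, hidden_dim, ffn_dim, num_experts = 6, 16, 32, 4
out = moe_glu_reference(
    torch.randn(tokens, hidden_dim),
    torch.randn(num_experts, hidden_dim, ffn_dim),
    torch.randn(num_experts, hidden_dim, ffn_dim),
    torch.randn(num_experts, ffn_dim, hidden_dim),
    torch.randn(tokens, num_experts),
)
print(out.shape)  # torch.Size([6, 16])
```

The per-expert Python loop above is the overhead a fused op would presumably avoid: routing, the gated feed-forward, and the weighted scatter back into the token dimension would run as a single compiled kernel on the NPU.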