optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
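The largest addition in this release is new PaliGemma support; the full new file `optimum/rbln/transformers/models/paligemma/modeling_paligemma.py` is shown below. As a quick orientation, the following is a hedged sketch of how the new class might be used: the compile/save arguments mirror the docstring example inside the diff, while the `AutoProcessor`/`generate()` lines, the `sample.jpg` path, and the prompt are assumptions about a typical Hugging Face inference flow and are not part of this diff.

```python
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNPaliGemmaForConditionalGeneration

# Compile for RBLN NPUs (mirrors the docstring example in the diff below).
model = RBLNPaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-3b-mix-224",
    export=True,
    rbln_config={"language_model": {"prefill_chunk_size": 8192}},
    rbln_tensor_parallel_size=4,
)
model.save_pretrained("compiled-paligemma2-3b-mix-224")

# Assumed inference flow: preprocess an image + prompt, then generate.
processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
inputs = processor(text="caption en", images=Image.open("sample.jpg"), return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```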
@@ -0,0 +1,564 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union
+
+ import torch
+ from transformers import AutoModelForVision2Seq, PaliGemmaForConditionalGeneration, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
+ from transformers.models.paligemma.modeling_paligemma import PaligemmaModelOutputWithPast, PaliGemmaMultiModalProjector
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower: "RBLNModel"):
+         super().__init__(model=vision_tower.model[0])
+
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         out_buffer = kwargs["out"][index : index + 1]
+         return ([pixel_values_item], {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         return BaseModelOutputWithPooling(
+             last_hidden_state=kwargs["out"],
+         )
+
+
+ class RBLNPaliGemmaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
+     """
+     RBLNPaliGemmaForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNPaliGemmaForConditionalGeneration class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNPaliGemmaForConditionalGeneration
+
+         model = RBLNPaliGemmaForConditionalGeneration.from_pretrained(
+             "google/paligemma2-3b-mix-224",
+             export=True,
+             rbln_config={
+                 "language_model": {
+                     "prefill_chunk_size": 8192,
+                 }
+             },
+             rbln_tensor_parallel_size=4,
+         )
+
+         model.save_pretrained("compiled-paligemma2-3b-mix-224")
+         ```
+     """
+
+     auto_model_class = AutoModelForVision2Seq
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(PaliGemmaForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def _update_submodule_rbln_config(
+         cls,
+         submodule_name: str,
+         submodule_cls: Type["RBLNModel"],
+         model: "PreTrainedModel",
+         submodule_config: PretrainedConfig,
+         submodule_rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         if submodule_name == "language_model":
+             submodule_config.use_sliding_window = False
+         else:
+             return submodule_rbln_config
+
+         return submodule_rbln_config
+
+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+         new_language_model.lm_head = model.lm_head
+         new_language_model.model = model.model.language_model
+         model.model.language_model = new_language_model
+         model.lm_head = None
+         del model.lm_head
+         return model
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = self.rbln_submodules[1]
+
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.embed_tokens = self._create_embedding_layer()
+         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         self.multi_modal_projector = self._create_multi_modal_projector()
+         self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+         return super().__post_init__(**kwargs)
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PaliGemmaForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         save_dict = {}
+         save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+         save_dict["multi_modal_projector"] = model.multi_modal_projector.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.text_config.vocab_size,
+                 self.config.text_config.hidden_size,
+                 self.config.text_config.pad_token_id,
+             )
+         return embed_tokens
+
+     def _create_multi_modal_projector(self):
+         with no_init_weights():
+             multi_modal_projector = PaliGemmaMultiModalProjector(self.config)
+         return multi_modal_projector
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         inputs_embeds=None,
+         pixel_values=None,
+         image_sizes=None,
+         attention_mask=None,
+         generate_idx=None,
+         position_ids=None,
+         token_type_ids=None,
+         **kwargs,
+     ):
+         # Prepare HF generation
+         is_prefill_phase = generate_idx is None
+
+         model_inputs = self.language_model.prepare_inputs_for_generation(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             generate_idx=generate_idx,  # Not affect
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             **kwargs,
+         )
+
+         if is_prefill_phase:
+             model_inputs.update(
+                 {
+                     "pixel_values": pixel_values,
+                     "token_type_ids": token_type_ids,
+                 }
+             )
+
+         model_inputs["attention_mask"] = attention_mask
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         return model_kwargs
+
+     def get_image_features(self, pixel_values: torch.Tensor):
+         vision_output_size = [
+             pixel_values.shape[0],
+             self.config.vision_config.num_image_tokens,
+             self.config.vision_config.hidden_size,
+         ]
+         vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+         self.vision_tower(pixel_values, out=vision_output)
+         image_features = self.multi_modal_projector(vision_output)
+         image_features = image_features / (self.config.text_config.hidden_size**0.5)
+         return image_features
+
+     def get_placeholder_mask(
+         self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+     ):
+         if input_ids is None:
+             special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                 torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+             )
+             special_image_mask = special_image_mask.all(-1)
+         else:
+             special_image_mask = input_ids == self.config.image_token_id
+
+         n_image_tokens = special_image_mask.sum()
+         special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+         n_image_features = image_features.shape[0] * image_features.shape[1]
+         if inputs_embeds[special_image_mask].numel() != image_features.numel():
+             raise ValueError(
+                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+             )
+         return special_image_mask
+
+     def _preprocess_prefill(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ):
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if input_ids is not None and self.config.image_token_id >= self.config.text_config.vocab_size:
+             special_image_mask = input_ids == self.config.image_token_id
+             llm_input_ids = input_ids.clone()
+             llm_input_ids[special_image_mask] = 0
+         else:
+             llm_input_ids = input_ids
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+         if pixel_values is not None:
+             image_features = self.get_image_features(pixel_values)
+             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+             special_image_mask = self.get_placeholder_mask(
+                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+             )
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         pixel_values: torch.FloatTensor = None,
+         attention_mask: torch.LongTensor = None,
+         position_ids: torch.LongTensor = None,
+         token_type_ids: torch.LongTensor = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         cache_position: torch.Tensor = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+         # Prefill
+         if cache_position is None:
+             inputs_embeds = self._preprocess_prefill(
+                 input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+             )
+             logits = []
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 output = self.language_model.prefill_decoder(
+                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                 )
+                 logits.append(output.logits)
+
+             logits = torch.cat(logits, dim=0)
+         # Decoder
+         else:
+             logits = self.language_model.decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
+             ).logits
+
+         if not return_dict:
+             return logits, generate_idx
+         else:
+             return RBLNDecoderOnlyOutput(
+                 logits=logits,
+                 generate_idx=generate_idx,
+             )
+
+
+ class RBLNPaliGemmaModel(RBLNModel):
+     """
+     RBLNPaliGemmaModel which consists of a vision backbone and a language model without language modeling head,
+     optimized for RBLN NPUs.
+
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNPaliGemmaModel class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNPaliGemmaModel
+
+         model = RBLNPaliGemmaModel.from_pretrained(
+             "google/paligemma2-3b-mix-224",
+             export=True,
+             rbln_config={
+                 "language_model": {
+                     "prefill_chunk_size": 8192,
+                 }
+             },
+             rbln_tensor_parallel_size=4,
+         )
+
+         model.save_pretrained("compiled-paligemma2-3b-mix-224")
+         ```
+     """
+
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = self.rbln_submodules[1]
+
+         if not isinstance(self.config.text_config, PretrainedConfig):
+             cfg = self.config if isinstance(self.config, dict) else self.config.to_dict()
+             text_config = cfg.pop("text_config", None)
+             vision_config = cfg.pop("vision_config", None)
+             self.config = PaliGemmaConfig(
+                 text_config=text_config,
+                 vision_config=vision_config,
+                 **cfg,
+             )
+
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.embed_tokens = self._create_embedding_layer()
+         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         self.multi_modal_projector = self._create_multi_modal_projector()
+         self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+         return super().__post_init__(**kwargs)
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def _update_submodule_rbln_config(
+         cls,
+         submodule_name: str,
+         submodule_cls: Type["RBLNModel"],
+         model: "PreTrainedModel",
+         submodule_config: PretrainedConfig,
+         submodule_rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         if submodule_name == "language_model":
+             submodule_config.use_sliding_window = False
+         else:
+             return submodule_rbln_config
+
+         return submodule_rbln_config
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PaliGemmaForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         save_dict = {}
+         save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+         save_dict["multi_modal_projector"] = model.multi_modal_projector.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.text_config.vocab_size,
+                 self.config.text_config.hidden_size,
+                 self.config.text_config.pad_token_id,
+             )
+         return embed_tokens
+
+     def _create_multi_modal_projector(self):
+         with no_init_weights():
+             multi_modal_projector = PaliGemmaMultiModalProjector(self.config)
+         return multi_modal_projector
+
+     def get_image_features(self, pixel_values: torch.Tensor):
+         vision_output_size = [
+             pixel_values.shape[0],
+             self.config.vision_config.num_image_tokens,
+             self.config.vision_config.hidden_size,
+         ]
+         vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+         self.vision_tower(pixel_values, out=vision_output)
+         image_features = self.multi_modal_projector(vision_output)
+         image_features = image_features / (self.config.text_config.hidden_size**0.5)
+         return image_features
+
+     def get_placeholder_mask(
+         self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+     ):
+         if input_ids is None:
+             special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                 torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+             )
+             special_image_mask = special_image_mask.all(-1)
+         else:
+             special_image_mask = input_ids == self.config.image_token_index
+
+         n_image_tokens = special_image_mask.sum()
+         special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+         n_image_features = image_features.shape[0] * image_features.shape[1]
+         if inputs_embeds[special_image_mask].numel() != image_features.numel():
+             raise ValueError(
+                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+             )
+         return special_image_mask
+
+     def _preprocess_prefill(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ):
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if input_ids is not None and self.config.image_token_index >= self.config.text_config.vocab_size:
+             special_image_mask = input_ids == self.config.image_token_index
+             llm_input_ids = input_ids.clone()
+             llm_input_ids[special_image_mask] = 0
+         else:
+             llm_input_ids = input_ids
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+         image_features = None
+         if pixel_values is not None:
+             image_features = self.get_image_features(pixel_values)
+             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+             special_image_mask = self.get_placeholder_mask(
+                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+             )
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds, image_features
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, PaligemmaModelOutputWithPast]:
+         """
+         Forward pass for the RBLN-optimized PaliGemmaModel model.
+
+         Args:
+             input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary.
+             pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images.
+             attention_mask (torch.Tensor of shape (batch_size, sequence_length)) — Mask to avoid performing attention on padding token indices.
+             position_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of positions of each input sequence tokens in the position embeddings.
+             token_type_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Segment token indices to indicate first and second portions of the inputs.
+             output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+             return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             PaligemmaModelOutputWithPast or tuple(torch.FloatTensor)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+         )
+         if output_hidden_states != self.rbln_config.output_hidden_states:
+             raise ValueError(
+                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                 f"Please compile again with the correct argument."
+             )
+
+         inputs_embeds, image_features = self._preprocess_prefill(
+             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+         )
+
+         outputs = self.language_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_hidden_states=output_hidden_states,
+         )
+
+         return PaligemmaModelOutputWithPast(
+             last_hidden_state=outputs.last_hidden_state,
+             image_hidden_states=image_features if pixel_values is not None else None,
+             hidden_states=outputs.hidden_states if output_hidden_states else None,
+         )
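Both PaliGemma classes above run the vision tower one sample at a time through `LoopVisionTower`, writing each result into a slice of a preallocated full-batch buffer. The following is a hedged, self-contained illustration of that pattern; `run_single` and `dummy_single` are stand-ins for the compiled RBLN vision-tower runtime (assumed to accept a per-sample `out=` buffer), and the 256/1152 sizes are placeholder values, not taken from this diff.

```python
import torch


def run_batched(run_single, pixel_values: torch.Tensor, num_image_tokens: int, hidden_size: int) -> torch.Tensor:
    # Preallocate the full-batch output buffer, then fill it one sample at a time.
    batch_size = pixel_values.shape[0]
    out = torch.empty(batch_size, num_image_tokens, hidden_size, dtype=torch.float32)
    for i in range(batch_size):
        # One single-sample invocation per batch element, writing into out[i : i + 1].
        run_single(pixel_values[i : i + 1], out=out[i : i + 1])
    return out


# Dummy single-sample function standing in for the RBLN runtime call.
def dummy_single(x, out):
    out.fill_(0.0)


features = run_batched(dummy_single, torch.randn(2, 3, 224, 224), num_image_tokens=256, hidden_size=1152)
print(features.shape)  # torch.Size([2, 256, 1152])
```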
@@ -60,10 +60,10 @@ class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
  class PegasusDecoder(Seq2SeqDecoder):
      has_pos_emb = True

-     def __post_init__(self):
-         self.embed_positions = self._original_mod.embed_positions
-         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
-         self.final_layer_norm = getattr(self._original_mod, "layer_norm", None)
+     def __post_init__(self, model: nn.Module):
+         self.embed_positions = model.embed_positions
+         self.embed_scale = getattr(model, "embed_scale", None)
+         self.final_layer_norm = getattr(model, "layer_norm", None)

      def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
          if attention_mask is not None:
@@ -110,11 +110,11 @@ class PegasusLayerFF(nn.Module):


  class PegasusDecoderLayer(Seq2SeqDecoderLayer):
-     def __post_init__(self):
-         self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
-         self.encoder_attn = self._original_mod.encoder_attn
-         self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
-         self.ff_layer = PegasusLayerFF(self._original_mod)
+     def __post_init__(self, decoder_layer: nn.Module):
+         self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+         self.encoder_attn = decoder_layer.encoder_attn
+         self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+         self.ff_layer = PegasusLayerFF(decoder_layer)

      def pre_self_attn_layer_norm(self, hidden_states):
          return self.self_attn_layer_norm(hidden_states)
@@ -130,13 +130,13 @@ class PegasusDecoderLayer(Seq2SeqDecoderLayer):


  class PegasusSelfAttention(Seq2SeqSelfAttention):
-     def __post_init__(self, use_attention_mask: bool = True):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.out_proj = self._original_mod.out_proj
-         self.num_heads = self._original_mod.num_heads
-         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+     def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
+         self.q_proj = attn.q_proj
+         self.k_proj = attn.k_proj
+         self.v_proj = attn.v_proj
+         self.out_proj = attn.out_proj
+         self.num_heads = attn.num_heads
+         self.head_dim = attn.embed_dim // attn.num_heads
          self.scaling = self.head_dim**-0.5
          if use_attention_mask:
              self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
@@ -151,11 +151,11 @@ class PegasusSelfAttention(Seq2SeqSelfAttention):


  class PegasusCrossAttention(Seq2SeqCrossAttention):
-     def __post_init__(self):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.out_proj = self._original_mod.out_proj
-         self.num_heads = self._original_mod.num_heads
-         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
-         self.embed_dim = self._original_mod.embed_dim
+     def __post_init__(self, attn: nn.Module):
+         self.q_proj = attn.q_proj
+         self.k_proj = attn.k_proj
+         self.v_proj = attn.v_proj
+         self.out_proj = attn.out_proj
+         self.num_heads = attn.num_heads
+         self.head_dim = attn.embed_dim // attn.num_heads
+         self.embed_dim = attn.embed_dim
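The Pegasus hunks above all follow the same refactor: the wrapper's `__post_init__` now receives the wrapped torch module as an explicit argument instead of reading it from `self._original_mod`. The following is a hedged sketch of that pattern; `WrapperBase` and `DummyAttention` are illustrative stand-ins, not the actual `Seq2Seq*` base classes in optimum-rbln.

```python
import torch.nn as nn


class WrapperBase(nn.Module):
    def __init__(self, original: nn.Module, *args, **kwargs):
        super().__init__()
        # The wrapped module is handed to __post_init__ explicitly.
        self.__post_init__(original, *args, **kwargs)


class DummyAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)


class SelfAttentionWrapper(WrapperBase):
    def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
        # References are shared with the original module; no weights are copied.
        self.q_proj, self.k_proj = attn.q_proj, attn.k_proj
        self.v_proj, self.out_proj = attn.v_proj, attn.out_proj
        self.head_dim = attn.embed_dim // attn.num_heads


wrapper = SelfAttentionWrapper(DummyAttention())
print(wrapper.head_dim)  # 16
```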