optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
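The headline addition in 0.9.4a2 is Stable Video Diffusion support (the new stable_video_diffusion configuration, autoencoder_kl_temporal_decoder, unet_spatio_temporal_condition, and pipeline files listed above). Below is a minimal usage sketch, assuming the pipeline is exported as RBLNStableVideoDiffusionPipeline and follows the from_pretrained(..., export=True) convention of the existing RBLN diffusers pipelines; the class name, keyword arguments, and image URL are illustrative assumptions, not taken from this diff.

# Hypothetical usage of the new Stable Video Diffusion support; the class name and
# arguments below are assumed to follow existing optimum-rbln conventions.
import torch
from diffusers.utils import load_image

from optimum.rbln import RBLNStableVideoDiffusionPipeline  # assumed export name

pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    export=True,  # compile from the original checkpoint on first use
)

image = load_image("https://example.com/conditioning_frame.png")  # placeholder image
frames = pipe(image, decode_chunk_size=8, generator=torch.manual_seed(42)).frames[0]

If the call signature matches the upstream diffusers pipeline, diffusers.utils.export_to_video can then write the returned frames to a video file.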
@@ -99,9 +99,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
@@ -135,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
@@ -301,28 +299,60 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.rbln_config.language_model.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # prefill
         if cache_position is None:
             logits = []
             inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
             batch_size = inputs_embeds.shape[0]
 
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        inputs_embeds.shape[1],
+                        self.config.text_config.hidden_size,
+                        dtype=self.rbln_config.torch_dtype,
+                    )
+                    for _ in range(self.config.text_config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.language_model.output_hidden_states
+                else None
+            )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
 
-                output = self.language_model.prefill_decoder(
+                outputs = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
-                padded_cache_lengths[b_idx] += output.padded_cache_lengths
-                logits.append(output.logits)
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.language_model.output_hidden_states:
+                    for l_idx in range(self.config.text_config.num_hidden_layers + 1):
+                        mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                        all_hidden_states[l_idx][b_idx].index_copy_(
+                            dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                        )
 
             logits = torch.cat(logits, dim=0)
         # decoder
@@ -336,15 +366,20 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
 
-            logits = self.language_model.decoders[batch_size](
+            outputs = self.language_model.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
-            ).logits
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         return RBLNDecoderOnlyOutput(
-            logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            logits=logits,
+            generate_idx=generate_idx,
+            padded_cache_lengths=padded_cache_lengths,
+            hidden_states=all_hidden_states,
         )
 
 
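The forward changes above thread output_hidden_states through both the prefill and decode paths and return the collected states on RBLNDecoderOnlyOutput. During prefill each sequence is processed without its padding, so the per-layer states are scattered back into a zero-initialized (batch, seq, hidden) buffer with index_copy_. A standalone sketch of that gathering pattern, with toy shapes that are not taken from the diff:

# Each sequence comes back from the prefill decoder without its padding; index_copy_
# scatters the per-layer states into a zero buffer at the positions where the
# attention mask is 1. Shapes below are toy values.
import torch

batch, max_seq, hidden, num_layers = 2, 6, 4, 3
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

all_hidden_states = tuple(torch.zeros(batch, max_seq, hidden) for _ in range(num_layers + 1))

for b_idx in range(batch):
    valid_len = int(attention_mask[b_idx].sum())
    # stand-in for outputs.hidden_states[l_idx][0] returned by the prefill decoder
    per_layer = [torch.randn(valid_len, hidden) for _ in range(num_layers + 1)]
    mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
    for l_idx in range(num_layers + 1):
        all_hidden_states[l_idx][b_idx].index_copy_(
            dim=0, index=mask_indices, source=per_layer[l_idx]
        )

# Padded positions stay zero; real tokens land at their original indices.
assert torch.all(all_hidden_states[0][0, 3:] == 0)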
@@ -405,26 +440,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             )
             return embed_tokens
 
-    @classmethod
-    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-        if sliding_window_pattern is None:
-            if hasattr(model_config, "layer_types"):
-                first_full_attention_index = model_config.layer_types.index("full_attention")
-                sliding_window_pattern = first_full_attention_index + 1
-            else:
-                raise ValueError("Cannot determine sliding_window_pattern from model_config")
-
-        if sliding_window_pattern <= model_config.num_hidden_layers:
-            rbln_config.cache_impl = "hybrid"
-            rbln_config.sliding_window = sliding_window
-            rbln_config.sliding_window_layers = [
-                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
-            ]
-
-        return rbln_config
-
     @classmethod
     def _update_submodule_config(
         cls,
@@ -482,7 +497,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
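This hunk and several below apply the same refactor across the model classes: the get_pytorch_model override is split out into a _reconstruct_model_if_needed hook that receives the already-loaded PreTrainedModel, and wrap_model_if_needed becomes the private _wrap_model_if_needed. A schematic of the new override surface; the hook names and signatures come from these hunks, but RBLNMyMultimodalModel is a made-up example class and the exact points at which the RBLNModel base class invokes the hooks are assumed.

# Schematic only; not a class that exists in the package.
from transformers import PreTrainedModel

from optimum.rbln import RBLNModel
from optimum.rbln.configuration_utils import RBLNModelConfig


class RBLNMyMultimodalModel(RBLNModel):
    @classmethod
    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
        # Rearrange the loaded checkpoint before compilation (e.g. re-attach a
        # causal-LM head) instead of overriding get_pytorch_model and calling super().
        return model

    @classmethod
    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
        # Return the torch.nn.Module that should actually be compiled for the NPU.
        return model.multi_modal_projector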
@@ -150,7 +150,7 @@ class _GroundingDinoEncoder(torch.nn.Module):
         all_attn_fused_vision = () if output_attentions else None
         all_attn_enhanced_text = () if output_attentions else None
         all_attn_deformable = () if output_attentions else None
-        for i, encoder_layer in enumerate(self.layers):
+        for _, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_vision_states += (vision_features,)
                 encoder_text_states += (text_features,)
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor, nn
@@ -206,8 +206,7 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.encoder = model.model.encoder
         model.decoder = model.model.decoder
         model.text_backbone = model.model.text_backbone
@@ -217,7 +216,7 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
         return model
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
     ) -> torch.nn.Module:
         return model.model.text_projection
@@ -305,7 +304,6 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
         for feature_map, mask in vision_features:
             # position encoding
             position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
-            vision_features, position_embeddings_list
 
         # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         feature_maps = []
@@ -530,9 +528,26 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
         **kwargs,
-    ):
+    ) -> Union[GroundingDinoObjectDetectionOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized GroundingDinoForObjectDetection model.
+
+        Args:
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+            input_ids (torch.LongTensor of shape (batch_size, text_sequence_length)): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
+            token_type_ids (torch.LongTensor of shape (batch_size, text_sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            pixel_mask (torch.Tensor of shape (batch_size, height, width), optional): Mask to avoid performing attention on padding pixel values.
+            encoder_outputs (Tuple consists of last_hidden_state of shape (batch_size, sequence_length, hidden_size), optional): A sequence of hidden-states at the output of the last layer of the encoder.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a GroundingDinoObjectDetectionOutput object.
+        """
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # Pad image to rbln_config.image_height and rbln_config.image_width
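For reference, a hedged inference sketch matching the forward signature documented above; the checkpoint id and processor calls mirror the upstream transformers Grounding DINO example, and export=True follows the usual optimum-rbln from_pretrained convention. None of this is taken verbatim from the diff.

# Hedged sketch; the export name RBLNGroundingDinoForObjectDetection and the
# compile-on-load behavior of export=True are assumed from package conventions.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNGroundingDinoForObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = RBLNGroundingDinoForObjectDetection.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs, inputs.input_ids, target_sizes=[image.size[::-1]]
)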
@@ -663,7 +678,7 @@ class RBLNGroundingDinoEncoder(RBLNModel):
         self.encoder_runtime = RBLNPytorchRuntime(self.model[0])
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
     ) -> torch.nn.Module:
         model = _GroundingDinoEncoder(model, rbln_config).eval()
@@ -861,7 +876,7 @@ class RBLNGroundingDinoDecoder(RBLNModel):
         self.decoder_runtime = RBLNPytorchRuntime(self.model[0])
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
     ) -> torch.nn.Module:
         return _GroundingDinoDecoder(model, rbln_config).eval()
@@ -110,7 +110,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         return self.embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Idefics3VisionTransformerWrapper(torch.nn.Module):
             def __init__(self, model: "Idefics3VisionTransformer"):
                 super().__init__()
@@ -240,9 +240,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationM
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.text_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -271,7 +269,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationM
         return self.text_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model, rbln_config):
+    def _wrap_model_if_needed(cls, model, rbln_config):
         return model.model.connector
 
     @classmethod
@@ -88,15 +88,22 @@ class LoopVisionTower(LoopProcessor):
 
 
 class LoopProjector(LoopProcessor):
-    def __init__(self, multi_modal_projector: "RBLNModel"):
+    def __init__(self, multi_modal_projector: "RBLNModel", rbln_config=None):
         super().__init__(model=multi_modal_projector)
+        self.rbln_config = rbln_config
 
     def _get_batch_size(self, image_feature, **kwargs):
         return image_feature.shape[0]
 
     def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
         image_feature_item = image_feature[index : index + 1]
-        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+            out_buffer = [
+                tensor[:, index * image_feature.shape[1] : (index + 1) * image_feature.shape[1], :]
+                for tensor in kwargs["out"]
+            ]
+        else:
+            out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
         return ([image_feature_item], {"out": out_buffer})
 
     def _process_outputs(self, outputs: list, **kwargs):
@@ -175,9 +182,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -194,7 +199,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
-        self.multi_modal_projector = LoopProjector(self.model[0])
+        self.multi_modal_projector = LoopProjector(self.model[0], rbln_config=self.rbln_config)
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
         return super().__post_init__(**kwargs)
 
@@ -208,7 +213,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
@@ -221,10 +226,8 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
     ) -> RBLNModelConfig:
         # support for pixtral that needs padding
         if hasattr(rbln_config.vision_tower, "max_image_size"):
-            num_positions = (
-                rbln_config.batch_size
-                * (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size)
-                * (rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size)
+            num_positions = (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size) * (
+                rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size
             )
             selected_image_feature_dim = num_positions
 
@@ -334,7 +337,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
             pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]
 
         vision_out_buffer = []
-        for i in range(self.config.vision_config.num_hidden_layers + 2):
+        for _ in range(self.config.vision_config.num_hidden_layers + 2):
            vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
         if pooler_out_size is not None:
             vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
@@ -353,23 +356,32 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
 
         if hasattr(self.rbln_config.vision_tower, "max_image_size"):
             num_real_patches = selected_image_feature.shape[1]
-            max_patches = (
-                (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
-                * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size)
-                * pixel_values.shape[0]
+            max_patches = (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size) * (
+                self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size
             )
-            num_padding_patches = max_patches - num_real_patches
 
-            projector_out_size = [1, max_patches, self.config.text_config.hidden_size]
+            chunks = []
+            for i in range(0, num_real_patches, max_patches):
+                chunk = selected_image_feature[:, i : i + max_patches, :]
+                chunk_size = chunk.shape[1]
+
+                if chunk_size < max_patches:
+                    padding_tensor = torch.zeros(
+                        (selected_image_feature.shape[0], max_patches - chunk_size, selected_image_feature.shape[2]),
+                        dtype=selected_image_feature.dtype,
+                    )
+                    chunk = torch.cat([chunk, padding_tensor], dim=1)
+                chunks.append(chunk)
+
+            split_features = torch.cat(chunks, dim=0)
+            num_chunks = len(chunks)
+            projector_out_size = [1, max_patches * num_chunks, self.config.text_config.hidden_size]
             projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
-
-            padding_tensor = torch.zeros(
-                (selected_image_feature.shape[0], num_padding_patches, selected_image_feature.shape[2]),
-                dtype=selected_image_feature.dtype,
+            projected_features = self.multi_modal_projector(split_features, out=projector_out_buffer)
+            projected_features = projected_features.view(
+                selected_image_feature.shape[0], num_chunks * max_patches, self.config.text_config.hidden_size
             )
-            padded_feature = torch.cat([selected_image_feature, padding_tensor], dim=1)
-            padded_projected_feature = self.multi_modal_projector(padded_feature, out=projector_out_buffer)
-            image_features = padded_projected_feature[:, :num_real_patches, :]
+            image_features = projected_features[:, :num_real_patches, :]
         else:
             projector_out_size = [
                 pixel_values.shape[0] * pixel_values.shape[1],
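The Llava change above replaces the single batch-wide padding buffer with per-chunk processing: the patch features are split into max_patches-sized chunks, the last chunk is zero-padded, each chunk goes through the compiled projector, and the padding is trimmed afterwards. A standalone illustration with toy shapes and a plain Linear standing in for the compiled multi_modal_projector runtime:

# Illustration of the chunk-pad-project-trim pattern; shapes are toy values and the
# projector is a stand-in, not the RBLN runtime.
import torch

max_patches, hidden, text_hidden = 4, 8, 16
selected_image_feature = torch.randn(1, 10, hidden)  # 10 real patches
projector = torch.nn.Linear(hidden, text_hidden)     # stand-in for the compiled projector

num_real_patches = selected_image_feature.shape[1]
chunks = []
for i in range(0, num_real_patches, max_patches):
    chunk = selected_image_feature[:, i : i + max_patches, :]
    if chunk.shape[1] < max_patches:
        pad = torch.zeros(1, max_patches - chunk.shape[1], hidden, dtype=chunk.dtype)
        chunk = torch.cat([chunk, pad], dim=1)
    chunks.append(chunk)

split_features = torch.cat(chunks, dim=0)             # (num_chunks, max_patches, hidden)
projected = projector(split_features)                 # each chunk matches the compiled shape
projected = projected.view(1, len(chunks) * max_patches, text_hidden)
image_features = projected[:, :num_real_patches, :]   # drop the zero padding
assert image_features.shape == (1, 10, text_hidden)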
@@ -139,9 +139,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -192,7 +190,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
@@ -302,7 +300,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGeneration
         ]
         pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size]
         vision_out_buffer = []
-        for i in range(self.config.vision_config.num_hidden_layers + 2):
+        for _ in range(self.config.vision_config.num_hidden_layers + 2):
            vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
         vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
 
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from transformers import PretrainedConfig
 
 from ....utils import logging
 from ...models.decoderonly import (
     RBLNDecoderOnlyModel,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNDecoderOnlyModelForCausalLMConfig,
 )
 from .mistral_architecture import MistralWrapper
 
@@ -85,16 +83,6 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = MistralWrapper
 
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-        return rbln_config
-
 
 class RBLNMistralModel(RBLNDecoderOnlyModel):
     """
@@ -103,13 +91,3 @@ class RBLNMistralModel(RBLNDecoderOnlyModel):
     """
 
     _decoder_wrapper_cls = MistralWrapper
-
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-        return rbln_config
@@ -69,7 +69,7 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return layer
 
     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         for i in range(len(model.model.decoder.layers)):
             model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
 
@@ -95,7 +95,7 @@ class RBLNOPTModel(RBLNDecoderOnlyModel):
         return layer
 
     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         for i in range(len(model.decoder.layers)):
             model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
 
@@ -54,7 +54,7 @@ class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
         return PegasusWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
@@ -229,7 +229,7 @@ class RBLNPixtralVisionModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNPixtralVisionModelConfig
     ) -> torch.nn.Module:
         wrapper_cfg = {
@@ -293,6 +293,18 @@ class RBLNPixtralVisionModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutput]:
+        """
+        Forward pass for the RBLN-optimized Pixtral vision model.
+
+        Args:
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images. Pixel values can be obtained using PixtralImageProcessor. See PixtralImageProcessor.__call__() for details (PixtralProcessor uses PixtralImageProcessor for processing images).
+            image_sizes (torch.Tensor of shape (batch_size, 2), optional): The sizes of the images in the batch, being (height, width) for each image.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutput or tuple(torch.FloatTensor)
+        """
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
         )
@@ -24,8 +24,8 @@ class PixtralAttention(nn.Module):
     def __init__(self, self_attention):
         super().__init__()
         self.original_model = self_attention
-        self.num_heads = getattr(self.original_model, "num_heads", None) or getattr(
-            self.original_model.config, "num_attention_heads"
+        self.num_heads = (
+            getattr(self.original_model, "num_heads", None) or self.original_model.config.num_attention_heads
         )
         self.head_dim = self.original_model.head_dim
         self.scaling = self.head_dim**-0.5
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from transformers import PretrainedConfig
 
 from ....utils import logging
 from ...models.decoderonly import (
     RBLNDecoderOnlyModel,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNDecoderOnlyModelForCausalLMConfig,
 )
 from .qwen2_architecture import QWEN2Wrapper
 
@@ -87,19 +85,6 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = QWEN2Wrapper
 
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        # https://github.com/huggingface/transformers/issues/35896
-        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
-        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
-
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-        return rbln_config
-
 
 class RBLNQwen2Model(RBLNDecoderOnlyModel):
     """
@@ -108,16 +93,3 @@ class RBLNQwen2Model(RBLNDecoderOnlyModel):
     """
 
     _decoder_wrapper_cls = QWEN2Wrapper
-
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        # https://github.com/huggingface/transformers/issues/35896
-        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
-        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
-
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-        return rbln_config
@@ -88,7 +88,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
     ):
         return Qwen2_5_VisionTransformerWrapper(model).eval()
@@ -111,10 +111,10 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         model_config: "PretrainedConfig" = None,
         rbln_config: Optional[RBLNQwen2_5_VisionTransformerPretrainedModelConfig] = None,
     ) -> RBLNQwen2_5_VisionTransformerPretrainedModelConfig:
-        window_size = getattr(model_config, "window_size")
-        patch_size = getattr(model_config, "patch_size")
-        hidden_size = getattr(model_config, "hidden_size")
-        num_heads = getattr(model_config, "num_heads")
+        window_size = model_config.window_size
+        patch_size = model_config.patch_size
+        hidden_size = model_config.hidden_size
+        num_heads = model_config.num_heads
         head_dim = hidden_size // num_heads
         window_seq_len = (window_size // patch_size) ** 2
 
@@ -294,10 +294,10 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         try:
             ws_index = torch.searchsorted(self.max_seq_lens, window_padded_len).item()
             max_seq_len = self.max_seq_lens[ws_index]
-        except Exception:
+        except Exception as e:
             raise ValueError(
                 f"Required seq_len({window_padded_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
-            )
+            ) from e
 
         # Padding for Window Attention Layers
         hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths = (
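The try/except change above only adds exception chaining, but the surrounding pattern is worth spelling out: max_seq_lens holds the window lengths the vision encoder was compiled for, and torch.searchsorted picks the smallest compiled bucket that fits the padded window length. A minimal reproduction with toy bucket sizes (the values are not from the diff):

# Bucket selection with searchsorted; re-raising with "from e" (the change in this
# hunk) preserves the original IndexError in the traceback.
import torch

max_seq_lens = torch.tensor([1024, 2048, 4096])  # toy compiled buckets, sorted ascending

def pick_max_seq_len(window_padded_len: int) -> int:
    try:
        ws_index = torch.searchsorted(max_seq_lens, window_padded_len).item()
        return int(max_seq_lens[ws_index])
    except Exception as e:
        raise ValueError(
            f"Required seq_len({window_padded_len}) is larger than available max_seq_lens({max_seq_lens.tolist()})."
        ) from e

print(pick_max_seq_len(1500))  # 2048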
@@ -393,8 +393,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.model.lm_head = model.lm_head
         model.lm_head = None
         del model.lm_head