optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
@@ -88,8 +88,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     def setup_runtime(self):
         # Initialize resources to be used across Runtime instances (prefill and decode phases)
         page_table_manager = RBLNPageTableManager(self.rbln_config)
-        dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
-        out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+        if self.rbln_config.use_position_ids:
+            dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+        else:
+            dec_attn_mask = torch.zeros(
+                self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+            )
 
         common_kwargs = {
             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
@@ -97,12 +101,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "dec_attn_mask": dec_attn_mask,
             "page_table_manager": page_table_manager,
             "rbln_config": self.rbln_config,
+            "config": self.config,
         }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-            out_buffers=out_buffers,
+            logits_last_dim=self.logits_last_dim,
             **common_kwargs,
         )
         if self.can_generate():
@@ -119,12 +124,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             self.decoder = self.decoders[self.rbln_config.batch_size]
 
     @property
-    def prefill_output_size(self):
-        return (
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.hidden_size,
-        )
+    def logits_last_dim(self):
+        return self.config.hidden_size
 
     @classmethod
     def get_quantized_model(
@@ -216,7 +217,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return self.rbln_config.kvcache_num_blocks
 
     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
 
     @classmethod
@@ -272,7 +273,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
         prefill_compile_config = rbln_config.compile_cfgs[0]
 
         # Here we use meta tensor, for the memory efficiency.
@@ -340,10 +341,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1
 
@@ -439,10 +440,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         # Returns:
         # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
-        raise NotImplementedError(
-            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
-            "See method docstring for required configuration details."
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
        )
+        return rbln_config
 
     @classmethod
     def _update_attention_config(
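The base class now derives a default sliding-window setup instead of forcing every subclass to override _update_sliding_window_config: layers marked "sliding_attention" in the config's layer_types become sliding-window layers, and the cache implementation is "sliding_window" only when all layers use it, otherwise "hybrid". A standalone paraphrase of that selection logic (the dicts below are illustrative stand-ins for model_config/rbln_config, not optimum-rbln objects):

# Paraphrase of the default _update_sliding_window_config logic in the hunk above.
model_config = {
    "sliding_window": 512,
    "num_hidden_layers": 4,
    "layer_types": ["sliding_attention", "full_attention", "sliding_attention", "full_attention"],
}

sliding_window_layers = [
    i
    for i in range(model_config["num_hidden_layers"])
    if model_config.get("layer_types") is None
    or model_config["layer_types"][i] == "sliding_attention"
]
cache_impl = (
    "sliding_window"
    if len(sliding_window_layers) == model_config["num_hidden_layers"]
    else "hybrid"
)
print(sliding_window_layers, cache_impl)  # [0, 2] hybrid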
@@ -466,13 +479,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
         # Update kvcache_num_blocks based on the attention implementation.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.get_maximum_num_blocks(
-                config=model_config,
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
+            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
+                model=model, model_config=model_config, rbln_config=rbln_config
             )
 
             if rbln_config.kvcache_num_blocks is None:
@@ -511,7 +519,6 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                     f" than the required number of blocks ({num_full_blocks})."
                    "This can cause a failure during model compilation."
                )
-
        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
 
        return rbln_config
@@ -531,8 +538,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        if rbln_config.max_seq_len is None:
            raise ValueError("`max_seq_len` should be specified.")
 
-        if getattr(model_config, "sliding_window", None) is not None and getattr(
-            model_config, "use_sliding_window", True
+        layer_types = getattr(model_config, "layer_types", None)
+        all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+        if (
+            getattr(model_config, "sliding_window", None) is not None
+            and getattr(model_config, "use_sliding_window", True)
+            and not all_full_attention
        ):
            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
            if rbln_config.sliding_window is not None:
@@ -608,34 +620,74 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
        position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.
+
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
        inputs = inputs_embeds if inputs_embeds is not None else input_ids
        batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)
 
        if batch_size != self.rbln_config.batch_size:
            raise ValueError(
                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
            )
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
        all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
        for b_idx in range(self.rbln_config.batch_size):
            query_length = (
                attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
            )
            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-            last_hidden_states = self.prefill_decoder(
-                inputs[b_idx : b_idx + 1],
+            outputs = self.prefill_decoder(
+                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                cache_position=cache_position,
                batch_idx=b_idx,
-            ).logits
-            all_last_hidden_states.append(last_hidden_states)
+            )
+            all_last_hidden_states.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
 
        last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
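RBLNDecoderOnlyModel.forward now accepts position_ids and output_hidden_states and returns a BaseModelOutputWithPast whose hidden_states buffers are filled per batch index. A hedged usage sketch, assuming a decoder-only checkpoint compiled with hidden-state output enabled; the wrapper class, model ID, and the rbln_output_hidden_states keyword spelling are illustrative assumptions, not confirmed by this diff:

import torch
from optimum.rbln import RBLNLlamaModel  # hypothetical class choice for illustration

# Assumption: hidden-state output is fixed at compile time via the RBLN config,
# mirroring the rbln_config.output_hidden_states checks in the hunk above.
model = RBLNLlamaModel.from_pretrained(
    "meta-llama/Llama-3.2-1B",       # illustrative model ID
    export=True,
    rbln_output_hidden_states=True,  # assumed kwarg spelling for the config field
)

input_ids = torch.randint(0, 100, (1, 8))
outputs = model(input_ids=input_ids, output_hidden_states=True)
print(outputs.last_hidden_state.shape)
print(len(outputs.hidden_states))  # num_hidden_layers + 1, per the buffers allocated above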
@@ -661,12 +713,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForCausalLM
 
     @property
-    def prefill_output_size(self):
-        return (
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.vocab_size,
-        )
+    def logits_last_dim(self):
+        return self.config.vocab_size
 
     @classmethod
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
@@ -731,6 +779,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -754,24 +803,55 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
        )
        padded_cache_lengths = torch.zeros_like(generate_idx)
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
        # Prefill
        if cache_position is None:
            logits = []
            inputs = inputs_embeds if inputs_embeds is not None else input_ids
            batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
            for b_idx in range(batch_size):
                cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-                output = self.prefill_decoder(
+                outputs = self.prefill_decoder(
                    input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                    inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                    attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                    cache_position=cache_position,
                    batch_idx=b_idx,
                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                )
-                padded_cache_lengths[b_idx] += output.padded_cache_lengths
-                logits.append(output.logits)
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
            logits = torch.cat(logits, dim=0)
        # Decoder
        else:
@@ -783,17 +863,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                    f"Available batch sizes are: {list(self.decoders.keys())}. "
                    f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                )
-            logits = self.decoders[batch_size](
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
+            outputs = self.decoders[batch_size](
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                position_ids=position_ids if self.rbln_config.use_position_ids else None,
                lora_int_ids=lora_int_ids,
-            ).logits
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
        if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
        else:
            return RBLNDecoderOnlyOutput(
-                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
            )
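With return_dict=False, the causal-LM forward now returns a 4-tuple that appends the hidden states after padded_cache_lengths. A hedged unpacking sketch; `model` and `input_ids` are assumed to be an already-loaded RBLN causal-LM instance and a prefill batch, and are not defined here:

# With return_dict=False the output is now a 4-tuple rather than the previous 3-tuple.
logits, generate_idx, padded_cache_lengths, hidden_states = model(
    input_ids=input_ids,
    return_dict=False,
)
# hidden_states is None unless the model was compiled with output_hidden_states enabled.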
@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for Depth Anything V2
     models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DepthAnythingForDepthEstimation model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
 from ...modeling_generic import RBLNModelForQuestionAnswering
 
 
@@ -25,3 +30,22 @@ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        """
+        Forward pass for the RBLN-optimized DistilBERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
     models on RBLN devices, supporting monocular depth estimation from single images.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DPT model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
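Both depth-estimation wrappers now expose an explicit forward(pixel_values, **kwargs) returning a DepthEstimatorOutput. A hedged usage sketch following the usual optimum-rbln export pattern; the checkpoint ID and input resolution are illustrative, and exact compile-time kwargs may differ:

import torch
from optimum.rbln import RBLNDPTForDepthEstimation

# Illustrative checkpoint; export=True compiles the model for RBLN NPUs on first load.
model = RBLNDPTForDepthEstimation.from_pretrained("Intel/dpt-large", export=True)

pixel_values = torch.randn(1, 3, 384, 384)  # assumed input resolution for dpt-large
outputs = model(pixel_values)
print(outputs.predicted_depth.shape)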
@@ -64,6 +64,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@ class Gemma3TextModel(DecoderOnlyModel):
 
         sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -110,7 +114,9 @@ class Gemma3TextModel(DecoderOnlyModel):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
@@ -16,7 +16,7 @@ from typing import Optional
 import rebel
 import torch
 
-from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
 from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         super().__init__(*args, **kwargs)
         self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
         self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
-        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -106,6 +105,8 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         while step < query_length:
             if self.rbln_config.use_image_prefill:
                 # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                 query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
 
             if is_image_prefill:
-                logits = self.image_prefill(
+                outputs = self.image_prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -157,7 +158,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
-                logits = self.prefill(
+                outputs = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
 
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
             padded_cache_lengths += current_padded_cache_lengths
             step += num_processed_tokens
 
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-        lora_int_ids: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        if self.rbln_config.use_lora and lora_int_ids is None:
-            if self.lora_int_ids is None:
-                raise ValueError(
-                    "lora_int_id is required when using LoRA. "
-                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
-                )
-
-            lora_int_ids = self.lora_int_ids
-
-        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
-            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
 
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+            all_hidden_states = concatenated_hidden_states
 
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
         else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-            if self.batch_size < block_tables.shape[0]:
-                block_tables = block_tables[: self.batch_size]
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
 
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
+                output_logits = new_output_logits
 
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
 
-        return RBLNDecoderOnlyOutput(logits=logits)
+        return RBLNGemma3ForCausalLMOutput(
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
+        )
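The Gemma3 prefill runtime now collects per-chunk logits and hidden states and stitches them back together along the sequence dimension before trimming to query_length. A standalone toy illustration of that concatenation; the chunk count, layer count, and hidden size are made up:

import torch

# Toy stand-in for the per-chunk outputs gathered in the prefill loop above:
# two prefill chunks of 4 tokens each, 3 entries (num_hidden_layers + 1), hidden size 8.
chunk_hidden_states = [
    tuple(torch.randn(1, 4, 8) for _ in range(3)),
    tuple(torch.randn(1, 4, 8) for _ in range(3)),
]
query_length = 6  # actual prompt length; the last chunk is partially padding

concatenated = tuple(
    torch.cat([chunk[l_idx] for chunk in chunk_hidden_states], dim=1)[:, :query_length, :]
    for l_idx in range(3)
)
print(concatenated[0].shape)  # torch.Size([1, 6, 8])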