optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Tuple, Union
 
+import torch
 from transformers import BartForConditionalGeneration, PreTrainedModel
+from transformers.modeling_outputs import Seq2SeqModelOutput
 
 from ....utils.logging import get_logger
 from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
@@ -35,6 +37,25 @@ class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
     on RBLN devices, optimized for feature extraction use cases.
     """
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+        """
+        Forward pass for the RBLN-optimized BART model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a Seq2SeqModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     """
@@ -48,7 +69,7 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True
 
     @classmethod
-    def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )

optimum/rbln/transformers/models/bert/bert_architecture.py (new file)
@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
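
The wrapper above only normalizes model outputs before compilation: plain tensors pass through unchanged, and None entries are dropped from tuple outputs. A minimal, illustrative check of that behaviour, using a stand-in encoder rather than a real BERT model (the import path is assumed from the file listing above, and nothing here is part of the diff itself):

import torch

from optimum.rbln.transformers.models.bert.bert_architecture import BertModelWrapper  # path assumed from the listing


class DummyEncoder(torch.nn.Module):
    def forward(self, x):
        # Mimic an HF-style tuple output in which optional fields may be None.
        return (x * 2, None, x.sum())


wrapped = BertModelWrapper(DummyEncoder(), rbln_config=None)
out = wrapped(torch.ones(2, 3))
print(len(out))  # 2 -- the None entry is filtered out, leaving tensor-only outputs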

optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -12,15 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ....utils.logging import get_logger
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+)
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +41,46 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+            position_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of positions of each input sequence tokens in the position embeddings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        input_map = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        model_input_names = getattr(self.rbln_config, "model_input_names", None)
+        if model_input_names is None:
+            model_input_names = self.rbln_model_input_names
+
+        ordered_inputs = [input_map[name] for name in model_input_names if name in input_map]
+
+        return super().forward(*ordered_inputs, **kwargs)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
@@ -46,6 +93,27 @@ class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[MaskedLMOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
+
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
@@ -57,3 +125,24 @@ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
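
For context on the keyword-style forward added above: RBLNBertModel now accepts Hugging Face-style keyword arguments and reorders them to match rbln_model_input_names (or rbln_config.model_input_names, when set) before dispatching to the compiled runtime. A rough usage sketch, assuming an already-compiled RBLNBertModel instance named model (the tokenizer checkpoint is illustrative and not taken from this diff):

import torch
from transformers import AutoTokenizer

# Assumes `model` is an already-compiled/loaded RBLNBertModel instance; the
# checkpoint name below is illustrative only.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("RBLN NPUs run BERT feature extraction.", return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
print(outputs.last_hidden_state.shape)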

optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
@@ -25,6 +29,16 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
     """
 
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
 
 
 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
@@ -36,24 +50,34 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
         **kwargs,
     ):
         """
         Args:
-            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
+            num_query_tokens (Optional[int]): The number of query tokens passed through the Transformer.
+            image_text_hidden_size (Optional[int]): Dimensionality of the hidden state of the image-text fusion layer.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
         """
         super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size
 
 
 class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
+    """
+
     submodules = ["vision_model", "qformer", "language_model"]
 
     def __init__(
@@ -62,14 +86,15 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            qformer (Optional[RBLNModelConfig]): Configuration for the RBLN-optimized BLIP-2 Q-Former model.
            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -79,6 +104,12 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_model = self.init_submodule_config(RBLNBlip2VisionModelConfig, vision_model)
-        self.language_model = language_model
-        self.qformer = self.init_submodule_config(RBLNBlip2QFormerModelConfig, qformer)
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
+            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)

optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -14,7 +14,7 @@
 
 import inspect
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from transformers import (
@@ -30,38 +30,31 @@ from transformers.utils import logging
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    import rebel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class LoopProjector:
-    def __init__(self, language_projection) -> None:
-        self.language_projection = language_projection
+class LoopProjector(LoopProcessor):
+    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__(model=language_projection)
 
-    def forward(self, *args, **kwargs):
-        query_output = args[0]
+    def _get_batch_size(self, query_output, **kwargs):
+        return query_output.shape[0]
 
-        batch_size = query_output.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.language_projection(query_output[i : i + 1]))
-
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
+    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+        query_output_item = query_output[index : index + 1]
+        return ([query_output_item], {})
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.language_projection)
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = torch.cat(outputs, dim=0)
+        return output
 
 
 class RBLNBlip2VisionModel(RBLNModel):
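
LoopProjector above now inherits its per-sample batching loop from LoopProcessor, defined in optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (file 186 in the listing, not shown in this excerpt). The real base class is not reproduced here; a hypothetical stand-in, inferred only from how LoopProjector overrides it, might look roughly like this:

# Hypothetical stand-in for the LoopProcessor base class -- inferred from the
# LoopProjector overrides above, NOT taken from rbln_runtime_wrapper.py itself.
class LoopProcessorSketch:
    def __init__(self, model):
        self.model = model  # e.g. an RBLNModel or a compiled runtime object

    def __call__(self, *args, **kwargs):
        # Run the wrapped model once per sample and recombine the results;
        # subclasses supply the three hook methods used below.
        batch_size = self._get_batch_size(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            iter_args, iter_kwargs = self._prepare_inputs_for_iteration(index, {}, *args, **kwargs)
            outputs.append(self.model(*iter_args, **iter_kwargs))
        return self._process_outputs(outputs, **kwargs)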
@@ -72,11 +65,13 @@ class RBLNBlip2VisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2VisionModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2VisionModel") -> None:
                 super().__init__()
@@ -100,8 +95,7 @@ class RBLNBlip2VisionModel(RBLNModel):
             (
                 "pixel_values",
                 [
-                    # support for vllm CB (prefill)
-                    1,
+                    rbln_config.batch_size,
                     model_config.num_channels,
                     model_config.image_size,
                     model_config.image_size,
@@ -116,12 +110,21 @@
 
     def forward(
         self,
-        pixel_values,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        pixel_values: torch.FloatTensor,
         interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized Blip2VisionModel model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+            interpolate_pos_encoding (bool, optional): Whether to interpolate the positional encoding of the image embeddings. Defaults to False.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPooling or tuple(torch.FloatTensor): The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
         batch_size = pixel_values.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -151,11 +154,13 @@ class RBLNBlip2QFormerModel(RBLNModel):
     mechanisms for multimodal understanding tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2QFormerModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2QFormerModel"):
                 super().__init__()
@@ -178,7 +183,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 
@@ -199,7 +209,7 @@
             (
                 "query_embeds",
                 [
-                    1,
+                    rbln_config.batch_size,
                     rbln_config.num_query_tokens,
                     model_config.hidden_size,
                 ],
@@ -208,7 +218,7 @@
             (
                 "encoder_hidden_states",
                 [
-                    1,
+                    rbln_config.batch_size,
                     # image_text_hidden_size + cls token
                     rbln_config.image_text_hidden_size + 1,
                     model_config.encoder_hidden_size,
@@ -218,7 +228,7 @@
             (
                 "encoder_attention_mask",
                 # image_text_hidden_size + cls token
-                [1, rbln_config.image_text_hidden_size + 1],
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                 "int64",
             ),
         ]
@@ -230,17 +240,22 @@
     def forward(
         self,
         query_embeds: torch.FloatTensor,
-        query_length: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """
+        The forward pass for the RBLN-optimized Blip2QFormerModel model.
+
+        Args:
+            query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
+            encoder_hidden_states (torch.FloatTensor, optional): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
+            encoder_attention_mask (torch.FloatTensor, optional): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPoolingAndCrossAttentions or tuple(torch.FloatTensor): The model outputs. If `return_dict=False` is passed, returns a tuple of tensors. Otherwise, returns a `BaseModelOutputWithPoolingAndCrossAttentions` object.
+        """
         batch_size = query_embeds.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -265,7 +280,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         )
 
 
-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -348,7 +363,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model, rbln_config):
+    def _wrap_model_if_needed(cls, model, rbln_config):
         return model.language_projection
 
     @classmethod
@@ -433,3 +448,79 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         )
 
         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> List[torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.
+
+        Args:
+            pixel_values (torch.FloatTensor): Input images to be processed.
+            input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
+            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
+            inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
+            interpolate_pos_encoding (bool, optional, defaults to False) — Whether to interpolate the positional encoding of the image embeddings.
+        Returns:
+            A list of strings of length batch_size * num_captions.
+        """
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
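
A rough end-to-end sketch of the new generate() path above, assuming an already-compiled RBLNBlip2ForConditionalGeneration instance named model and its matching Hugging Face processor (the checkpoint and image URL are illustrative, not taken from this diff):

import requests
from PIL import Image
from transformers import AutoProcessor

# Assumes `model` is a compiled RBLNBlip2ForConditionalGeneration instance.
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))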

optimum/rbln/transformers/models/clip/configuration_clip.py
@@ -12,20 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNCLIPTextModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, **kwargs: Dict[str, Any]):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
         """
         Args:
             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
-            ValueError: If batch_size is not a positive integer.
+            ValueError: If `batch_size` is not a positive integer.
         """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
@@ -50,17 +50,20 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
         interpolate_pos_encoding: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
            image_size (Optional[int]): The size of input images. Can be an integer for square images,
                 a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            interpolate_pos_encoding (Optional[bool]): Whether or not to interpolate pre-trained position encodings. Defaults to `False`.
+            output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
+            output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
-            ValueError: If batch_size is not a positive integer.
+            ValueError: If `batch_size` is not a positive integer.
         """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
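
For reference, the updated RBLNCLIPVisionModelConfig signature can be exercised roughly as follows (a minimal sketch; it assumes the class is re-exported from optimum.rbln as in earlier releases, and all values are illustrative):

from optimum.rbln import RBLNCLIPVisionModelConfig  # top-level re-export assumed

vision_config = RBLNCLIPVisionModelConfig(
    batch_size=1,
    image_size=(224, 224),
    interpolate_pos_encoding=False,
    output_hidden_states=False,
)
print(vision_config.batch_size)  # 1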