optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_rope_utils.py

@@ -27,7 +27,7 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 from transformers import PretrainedConfig
@@ -35,13 +35,16 @@ from transformers import PretrainedConfig
 
 def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) -> Tuple["torch.Tensor", float]:
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies according to the original RoPE implementation
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -50,40 +53,38 @@ def _compute_default_rope_parameters(
     """
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    head_dim = (
-        config.head_dim
-        if hasattr(config, "head_dim") and config.head_dim is not None
-        else config.hidden_size // config.num_attention_heads
-    )
+    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
     dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # Compute the inverse frequencies
-    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
     return inv_freq, attention_factor
 
 
 def _compute_linear_scaling_rope_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) -> Tuple["torch.Tensor", float]:
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
     factor = config.rope_scaling["factor"]
 
     # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
 
     # Then applies linear scaling to the frequencies.
     # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
@@ -94,20 +95,23 @@ def _compute_linear_scaling_rope_parameters(
 
 def _compute_dynamic_ntk_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) -> Tuple["torch.Tensor", float]:
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length, used to update the dynamic RoPE at inference time.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -117,6 +121,17 @@ def _compute_dynamic_ntk_parameters(
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    elif isinstance(seq_len, torch.Tensor):
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+    else:
+        seq_len = max(seq_len, max_position_embeddings)
+
     # Process with chunk_size to reduce precesion error
     chunk_size = 4096
     chunks = (seq_len + chunk_size - 1) // chunk_size
@@ -140,13 +155,17 @@
     return final_inv_freq, attention_factor
 
 
-def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+def _compute_yarn_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with NTK scaling. Please refer to the
-    [original paper](https://arxiv.org/abs/2309.00071)
+    [original paper](https://huggingface.co/papers/2309.00071)
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -158,13 +177,25 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-    max_position_embeddings = config.max_position_embeddings
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
 
     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
 
     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -176,10 +207,13 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
         """Inverse dimension formula to find the dimension based on the number of rotations"""
         return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
 
-    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
         """Find dimension range bounds based on rotations"""
-        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
-        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
         return max(low, 0), min(high, dim - 1)
 
     def linear_ramp_factor(min, max, dim):
@@ -192,38 +226,40 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
 
     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
-    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+    pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
 
-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    truncate = config.rope_scaling.get("truncate", True)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
 
     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
     inv_freq = (
         inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
         + inv_freq_extrapolation * inv_freq_extrapolation_factor
     )
-
     return inv_freq, attention_factor
 
 
 def _compute_longrope_parameters(
-    config: PretrainedConfig, seq_len: Optional[int] = None
-) -> Tuple["torch.Tensor", float]:
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with LongRoPE scaling. Please refer to the
     [original implementation](https://github.com/microsoft/LongRoPE)
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
-
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -237,40 +273,40 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
 
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
-        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    if seq_len and seq_len > original_max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
-        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
-    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
     inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
 
     return inv_freq, attention_factor
 
 
 def _compute_llama3_parameters(
-    config: PretrainedConfig, seq_len: Optional[int] = None
-) -> Tuple["torch.Tensor", float]:
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies for llama 3.1.
 
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -278,7 +314,7 @@ def _compute_llama3_parameters(
         post-processing scaling factor applied to the computed cos/sin.
     """
     # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
 
     factor = config.rope_scaling["factor"]  # `8` in the original implementation
     low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
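Note on the YaRN hunk above: when `rope_scaling` carries DeepSeek-style `mscale` / `mscale_all_dim` entries, the attention factor becomes the ratio of two mscale values; otherwise `get_mscale(factor)` reproduces the previous `0.1 * math.log(factor) + 1.0` behaviour. A minimal standalone sketch of that selection logic (values are illustrative, not taken from the package):

    import math

    def get_mscale(scale: float, mscale: float = 1.0) -> float:
        # Mirrors the helper added in the diff above: no scaling for factors <= 1.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    factor = 4.0
    # Plain YaRN config (no mscale keys): same value as the old 0.1 * log(factor) + 1.0 formula.
    print(get_mscale(factor))  # ~1.1386
    # DeepSeek-style config: ratio of the two mscales (1.0 here because both are equal).
    mscale, mscale_all_dim = 1.0, 1.0
    print(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))  # 1.0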
optimum/rbln/transformers/models/__init__.py

@@ -88,12 +88,16 @@ _import_structure = {
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
     ],
     "qwen2_vl": [
         "RBLNQwen2VisionTransformerPretrainedModel",
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
     ],
     "decoderonly": [
         "RBLNDecoderOnlyModelConfig",
@@ -110,12 +114,14 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
+    "gemma2": ["RBLNGemma2ForCausalLM", "RBLNGemma2ForCausalLMConfig", "RBLNGemma2Model", "RBLNGemma2ModelConfig"],
     "gemma3": [
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
     ],
+    "gpt_oss": ["RBLNGptOssForCausalLM", "RBLNGptOssForCausalLMConfig"],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -132,6 +138,12 @@ _import_structure = {
         "RBLNPegasusForConditionalGenerationConfig",
         "RBLNPegasusModelConfig",
     ],
+    "paligemma": [
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
+    ],
     "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
     "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
     "pixtral": ["RBLNPixtralVisionModel", "RBLNPixtralVisionModelConfig"],
@@ -143,7 +155,9 @@ _import_structure = {
     ],
     "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
     "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
+    "qwen2_moe": ["RBLNQwen2MoeForCausalLM", "RBLNQwen2MoeForCausalLMConfig"],
     "qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
+    "qwen3_moe": ["RBLNQwen3MoeForCausalLM", "RBLNQwen3MoeForCausalLMConfig"],
     "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
     "roberta": [
         "RBLNRobertaForMaskedLM",
@@ -254,6 +268,7 @@ if TYPE_CHECKING:
     from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
+    from .gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig, RBLNGemma2Model, RBLNGemma2ModelConfig
     from .gemma3 import (
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
@@ -261,6 +276,7 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
     )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+    from .gpt_oss import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
     from .grounding_dino import (
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
@@ -281,6 +297,12 @@ if TYPE_CHECKING:
     from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
     from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
     from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
+    from .paligemma import (
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
+    )
     from .pegasus import (
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
@@ -295,14 +317,20 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
     )
+    from .qwen2_moe import RBLNQwen2MoeForCausalLM, RBLNQwen2MoeForCausalLMConfig
     from .qwen2_vl import (
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
     )
     from .qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig, RBLNQwen3Model, RBLNQwen3ModelConfig
+    from .qwen3_moe import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig
     from .resnet import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig
     from .roberta import (
         RBLNRobertaForMaskedLM,
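The registry additions above (Gemma2, GPT-OSS, PaliGemma, Qwen2-MoE, Qwen3-MoE, and the Qwen2/2.5-VL base models) extend `_import_structure` in optimum/rbln/transformers/models/__init__.py. Assuming the usual top-level re-export (the parallel +48-line change to optimum/rbln/__init__.py in this diff suggests it), a loading sketch; the checkpoint id and `rbln_*` kwargs are placeholders for illustration:

    from optimum.rbln import (
        RBLNGemma2ForCausalLM,
        RBLNGptOssForCausalLM,
        RBLNPaliGemmaForConditionalGeneration,
        RBLNQwen2MoeForCausalLM,
        RBLNQwen3MoeForCausalLM,
    )

    # Compile-on-load via the library's usual from_pretrained pattern.
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-2b-it",  # illustrative checkpoint id
        export=True,             # compile the checkpoint for RBLN NPUs on first load
        rbln_batch_size=1,       # illustrative compile-time option
    )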
optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py

@@ -12,10 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...configuration_generic import RBLNModelForAudioClassificationConfig
+from typing import Any, Optional
 
+from ....configuration_utils import RBLNModelConfig
+from ....utils.deprecation import deprecate_kwarg
 
-class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
+
+class RBLNASTForAudioClassificationConfig(RBLNModelConfig):
     """
     Configuration class for RBLNASTForAudioClassification.
     """
+
+    @deprecate_kwarg(old_name="num_mel_bins", version="0.10.0")
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        max_length: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            max_length (Optional[int]): Maximum length of the audio input in time dimension.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.max_length = max_length
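A hypothetical construction of the reworked config shown above; `num_mel_bins` is no longer stored on the config (it is read from the model config at compile time) and passing it now goes through the `deprecate_kwarg` shim slated for removal in 0.10.0:

    from optimum.rbln import RBLNASTForAudioClassificationConfig  # assuming the usual top-level re-export

    # batch_size defaults to 1; max_length must be resolvable at compile time if neither
    # the model config nor the feature extractor provides one (see _update_rbln_config below).
    cfg = RBLNASTForAudioClassificationConfig(batch_size=1, max_length=1024)
    print(cfg.batch_size, cfg.max_length)  # 1 1024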
optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py

@@ -12,17 +12,80 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...modeling_generic import RBLNModelForAudioClassification
+from typing import TYPE_CHECKING, Optional
 
+import torch
+from transformers import AutoModelForAudioClassification
+from transformers.modeling_outputs import SequenceClassifierOutput
 
-class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from .configuration_audio_spectrogram_transformer import RBLNASTForAudioClassificationConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, PretrainedConfig, PreTrainedModel
+
+
+class RBLNASTForAudioClassification(RBLNModel):
     """
     Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2.
-    This model inherits from [`RBLNModelForAudioClassification`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [RBLNModelForAudioClassification]. Check the superclass documentation for the generic methods the library implements for all its models.
 
-    A class to convert and run pre-trained transformer-based `ASTForAudioClassification` models on RBLN devices.
-    It implements the methods to convert a pre-trained transformers `ASTForAudioClassification` model into a RBLN transformer model by:
+    A class to convert and run pre-trained transformer-based ASTForAudioClassification models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers ASTForAudioClassification model into a RBLN transformer model by:
 
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN Compiler.
     """
+
+    auto_model_class = AutoModelForAudioClassification
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: "AutoFeatureExtractor" = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "PretrainedConfig" = None,
+        rbln_config: Optional[RBLNASTForAudioClassificationConfig] = None,
+    ) -> RBLNASTForAudioClassificationConfig:
+        num_mel_bins = getattr(model_config, "num_mel_bins", None)
+
+        if rbln_config.max_length is None:
+            rbln_config.max_length = getattr(model_config, "max_length", None)
+            for feature_extractor in preprocessors:
+                if hasattr(feature_extractor, "max_length"):
+                    rbln_config.max_length = feature_extractor.max_length
+                    break
+
+        if rbln_config.max_length is None:
+            raise ValueError("max_length should be specified!")
+
+        input_info = [
+            (
+                "input_values",
+                [rbln_config.batch_size, rbln_config.max_length, num_mel_bins],
+                "float32",
+            ),
+        ]
+
+        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
+        return rbln_config
+
+    def forward(self, input_values: torch.Tensor, **kwargs) -> SequenceClassifierOutput:
+        """
+        Forward pass for the RBLN-optimized Audio Spectrogram Transformer model for audio classification.
+
+        Args:
+            input_values (torch.FloatTensor of shape (batch_size, max_length, num_mel_bins)):
+                Float values mel features extracted from the raw audio waveform. Raw audio waveform can be obtained by
+                loading a .flac or .wav audio file into an array of type list[float], a numpy.ndarray or a torch.Tensor, *e.g.* via
+                the torchcodec library (pip install torchcodec) or the soundfile library (pip install soundfile).
+                To prepare the array into input_features, the [AutoFeatureExtractor] should be used for extracting the
+                mel features, padding and conversion into a tensor of type torch.FloatTensor.
+
+        Returns:
+            Returns a SequenceClassifierOutput object.
+        """
+
+        return super().forward(input_values, **kwargs)
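For reference, `_update_rbln_config` above freezes a single static input signature before compilation. Taking the common AudioSet checkpoint defaults as an assumption (`num_mel_bins=128` from the model config, `max_length=1024` from the feature extractor), a batch-size-1 config would compile against:

    batch_size, max_length, num_mel_bins = 1, 1024, 128  # assumed checkpoint defaults
    input_info = [("input_values", [batch_size, max_length, num_mel_bins], "float32")]
    print(input_info)  # [('input_values', [1, 1024, 128], 'float32')]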
optimum/rbln/transformers/models/auto/auto_factory.py

@@ -150,6 +150,7 @@ class _BaseAutoModelClass:
                 f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
                 f"`from_pretrained()` method directly from the `RBLN{config.architectures[0]}` class instead.",
                 UserWarning,
+                stacklevel=2,
             )
 
         return model_class
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -60,10 +60,10 @@ class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
 class BartDecoder(Seq2SeqDecoder):
     has_pos_emb = True
 
-    def __post_init__(self):
-        self.embed_positions = self._original_mod.embed_positions
-        self.layernorm_embedding = self._original_mod.layernorm_embedding
-        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+    def __post_init__(self, model: nn.Module):
+        self.embed_positions = model.embed_positions
+        self.layernorm_embedding = model.layernorm_embedding
+        self.embed_scale = getattr(model, "embed_scale", None)
 
     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
         if attention_mask is not None:
@@ -112,11 +112,11 @@ class BartLayerFF(nn.Module):
 
 
 class BartDecoderLayer(Seq2SeqDecoderLayer):
-    def __post_init__(self):
-        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
-        self.encoder_attn = self._original_mod.encoder_attn
-        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
-        self.ff_layer = BartLayerFF(self._original_mod)
+    def __post_init__(self, decoder_layer: nn.Module):
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn = decoder_layer.encoder_attn
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(decoder_layer)
 
     def pre_self_attn_layer_norm(self, hidden_states):
         return hidden_states
@@ -132,13 +132,13 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):
 
 
 class BartSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self, use_attention_mask: bool = True):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.out_proj = self._original_mod.out_proj
-        self.num_heads = self._original_mod.num_heads
-        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+    def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
         self.scaling = self.head_dim**-0.5
         if use_attention_mask:
             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
@@ -153,11 +153,11 @@ class BartSelfAttention(Seq2SeqSelfAttention):
 
 
 class BartCrossAttention(Seq2SeqCrossAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.out_proj = self._original_mod.out_proj
-        self.num_heads = self._original_mod.num_heads
-        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
-        self.embed_dim = self._original_mod.embed_dim
+    def __post_init__(self, attn: nn.Module):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
+        self.embed_dim = attn.embed_dim
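The BART architecture hunks above all follow the same refactor: `__post_init__` now receives the original torch module as an explicit argument instead of reading `self._original_mod`. An illustrative, self-contained sketch of the pattern (not the package's actual base classes):

    import torch.nn as nn

    class DummyAttention(nn.Module):
        def __init__(self, embed_dim: int = 16):
            super().__init__()
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)

    class AttentionWrapper(nn.Module):
        def __init__(self, attn: nn.Module):
            super().__init__()
            self.__post_init__(attn)  # the wrapped module is passed in explicitly

        def __post_init__(self, attn: nn.Module):
            # Reuse the original projection layers without keeping a _original_mod handle.
            self.q_proj = attn.q_proj
            self.k_proj = attn.k_proj
            self.v_proj = attn.v_proj

    wrapper = AttentionWrapper(DummyAttention())
    print(type(wrapper.q_proj).__name__)  # Linear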
optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Tuple, Union
 
+import torch
 from transformers import BartForConditionalGeneration, PreTrainedModel
+from transformers.modeling_outputs import Seq2SeqModelOutput
 
 from ....utils.logging import get_logger
 from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
@@ -35,6 +37,25 @@ class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
     on RBLN devices, optimized for feature extraction use cases.
     """
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+        """
+        Forward pass for the RBLN-optimized BART model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a Seq2SeqModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
@@ -48,7 +69,7 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True
 
     @classmethod
-    def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
+    def _wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
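A hedged usage sketch for the newly documented `RBLNBartModel.forward` above; the checkpoint id, the `export=True` compile-on-load flow, and the padding length are assumptions based on the library's usual `from_pretrained` pattern, not taken from this diff:

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNBartModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    # export=True compiles the checkpoint for RBLN NPUs on first load (library convention).
    model = RBLNBartModel.from_pretrained("facebook/bart-base", export=True)

    # The compiled graph uses static shapes, so pad to the compiled sequence length
    # (512 here is an assumption; match the value used at compile time).
    inputs = tokenizer(
        "RBLN-compiled BART feature extraction",
        return_tensors="pt",
        padding="max_length",
        max_length=512,
    )
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Per the docstring above, outputs follow Seq2SeqModelOutput by default.
    print(outputs.last_hidden_state.shape)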