optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_rope_utils.py

@@ -27,7 +27,7 @@
  # limitations under the License.

  import math
- from typing import Optional, Tuple
+ from typing import Optional

  import torch
  from transformers import PretrainedConfig
@@ -35,13 +35,16 @@ from transformers import PretrainedConfig

  def _compute_default_rope_parameters(
      config: Optional[PretrainedConfig] = None,
+     device: Optional["torch.device"] = None,
      seq_len: Optional[int] = None,
- ) -> Tuple["torch.Tensor", float]:
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies according to the original RoPE implementation
      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
              The current sequence length. Unused for this type of RoPE.
      Returns:
@@ -50,40 +53,38 @@ def _compute_default_rope_parameters(
      """
      base = config.rope_theta
      partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-     head_dim = (
-         config.head_dim
-         if hasattr(config, "head_dim") and config.head_dim is not None
-         else config.hidden_size // config.num_attention_heads
-     )
+     head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
      dim = int(head_dim * partial_rotary_factor)

      attention_factor = 1.0  # Unused in this type of RoPE

      # Compute the inverse frequencies
-     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
      return inv_freq, attention_factor


  def _compute_linear_scaling_rope_parameters(
      config: Optional[PretrainedConfig] = None,
+     device: Optional["torch.device"] = None,
      seq_len: Optional[int] = None,
- ) -> Tuple["torch.Tensor", float]:
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
              The current sequence length. Unused for this type of RoPE.
      Returns:
          Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
          post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
      """
-
      factor = config.rope_scaling["factor"]

      # Gets the default RoPE parameters
-     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)

      # Then applies linear scaling to the frequencies.
      # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
@@ -94,20 +95,23 @@ _compute_linear_scaling_rope_parameters

  def _compute_dynamic_ntk_parameters(
      config: Optional[PretrainedConfig] = None,
+     device: Optional["torch.device"] = None,
      seq_len: Optional[int] = None,
- ) -> Tuple["torch.Tensor", float]:
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
              The current sequence length, used to update the dynamic RoPE at inference time.
      Returns:
          Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
          post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
      """
-
+     # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
      base = config.rope_theta
      partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
      head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -117,6 +121,17 @@ def _compute_dynamic_ntk_parameters(

      attention_factor = 1.0  # Unused in this type of RoPE

+     # seq_len: default to max_position_embeddings, e.g. at init time
+     if seq_len is None:
+         seq_len = max_position_embeddings
+     elif isinstance(seq_len, torch.Tensor):
+         seq_len = torch.maximum(
+             seq_len,
+             torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+         )
+     else:
+         seq_len = max(seq_len, max_position_embeddings)
+
      # Process with chunk_size to reduce precesion error
      chunk_size = 4096
      chunks = (seq_len + chunk_size - 1) // chunk_size
@@ -140,13 +155,17 @@ def _compute_dynamic_ntk_parameters(
      return final_inv_freq, attention_factor


- def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+ def _compute_yarn_parameters(
+     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies with NTK scaling. Please refer to the
-     [original paper](https://arxiv.org/abs/2309.00071)
+     [original paper](https://huggingface.co/papers/2309.00071)
      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
              The current sequence length. Unused for this type of RoPE.
      Returns:
@@ -158,13 +177,25 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
      partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
      head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
      dim = int(head_dim * partial_rotary_factor)
-     max_position_embeddings = config.max_position_embeddings
      factor = config.rope_scaling["factor"]
+     attention_factor = config.rope_scaling.get("attention_factor")
+     mscale = config.rope_scaling.get("mscale")
+     mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+     original_max_position_embeddings = (
+         config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+     )
+
+     def get_mscale(scale, mscale=1):
+         if scale <= 1:
+             return 1.0
+         return 0.1 * mscale * math.log(scale) + 1.0

      # Sets the attention factor as suggested in the paper
-     attention_factor = config.rope_scaling.get("attention_factor")
      if attention_factor is None:
-         attention_factor = 0.1 * math.log(factor) + 1.0
+         if mscale and mscale_all_dim:
+             attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+         else:
+             attention_factor = get_mscale(factor)

      # Optional config options
      # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -176,10 +207,13 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
          """Inverse dimension formula to find the dimension based on the number of rotations"""
          return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

-     def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+     def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
          """Find dimension range bounds based on rotations"""
-         low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
-         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+         low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+         high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+         if truncate:
+             low = math.floor(low)
+             high = math.ceil(high)
          return max(low, 0), min(high, dim - 1)

      def linear_ramp_factor(min, max, dim):
@@ -192,38 +226,40 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =

      # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
      # to expand the possible context length. In other words, interpolation = apply scaling factor.
-     pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+     pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
      inv_freq_extrapolation = 1.0 / pos_freqs
      inv_freq_interpolation = 1.0 / (factor * pos_freqs)

-     low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+     truncate = config.rope_scaling.get("truncate", True)
+     low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)

      # Get n-dimensional rotational scaling corrected for extrapolation
-     inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+     inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
      inv_freq = (
          inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
          + inv_freq_extrapolation * inv_freq_extrapolation_factor
      )
-
      return inv_freq, attention_factor


  def _compute_longrope_parameters(
-     config: PretrainedConfig, seq_len: Optional[int] = None
- ) -> Tuple["torch.Tensor", float]:
+     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies with LongRoPE scaling. Please refer to the
      [original implementation](https://github.com/microsoft/LongRoPE)
      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
-             The current sequence length. Unused for this type of RoPE.
+             The current sequence length.
      Returns:
          Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
          post-processing scaling factor applied to the computed cos/sin.
      """
-
+     # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
      base = config.rope_theta
      partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
      head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -237,40 +273,40 @@ def _compute_longrope_parameters(
      # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
      # values to compute the default attention scaling factor, instead of using `factor`.
      if hasattr(config, "original_max_position_embeddings"):
-         max_position_embeddings = config.original_max_position_embeddings
-         expanded_max_position_embeddings = config.max_position_embeddings
-         factor = expanded_max_position_embeddings / max_position_embeddings
+         original_max_position_embeddings = config.original_max_position_embeddings
+         factor = config.max_position_embeddings / config.original_max_position_embeddings
      else:
-         max_position_embeddings = config.max_position_embeddings
-         expanded_max_position_embeddings = max_position_embeddings * factor
+         original_max_position_embeddings = config.max_position_embeddings

      # Sets the attention factor as suggested in the paper
      if attention_factor is None:
          if factor <= 1.0:
              attention_factor = 1.0
          else:
-             attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+             attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))

      # Compute the inverse frequencies -- scaled based on the target sequence length
-     if expanded_max_position_embeddings > max_position_embeddings:
-         ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+     if seq_len and seq_len > original_max_position_embeddings:
+         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
      else:
-         ext_factors = torch.tensor(short_factor, dtype=torch.float32)
-     inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+     inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
      inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

      return inv_freq, attention_factor


  def _compute_llama3_parameters(
-     config: PretrainedConfig, seq_len: Optional[int] = None
- ) -> Tuple["torch.Tensor", float]:
+     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+ ) -> tuple["torch.Tensor", float]:
      """
      Computes the inverse frequencies for llama 3.1.

      Args:
          config ([`~transformers.PretrainedConfig`]):
              The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
          seq_len (`int`, *optional*):
              The current sequence length. Unused for this type of RoPE.
      Returns:
@@ -278,7 +314,7 @@ def _compute_llama3_parameters(
          post-processing scaling factor applied to the computed cos/sin.
      """
      # Gets the default RoPE parameters
-     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)

      factor = config.rope_scaling["factor"]  # `8` in the original implementation
      low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
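The YaRN change above is the part most likely to affect numerics: when `rope_scaling` carries `mscale`/`mscale_all_dim` entries (as found in e.g. DeepSeek-style configs), the attention factor becomes a ratio of two mscale corrections instead of the plain `0.1 * ln(factor) + 1.0`; with the default `mscale=1`, `get_mscale(factor)` reduces to the old formula. A standalone sketch of that selection logic, with made-up `rope_scaling` values for illustration:

```python
import math

# Reproduces the attention-factor selection from the hunk above.
# The rope_scaling values are hypothetical; real ones come from the model config.
def get_mscale(scale, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

rope_scaling = {"factor": 4.0, "mscale": 1.0, "mscale_all_dim": 0.8, "attention_factor": None}

factor = rope_scaling["factor"]
attention_factor = rope_scaling.get("attention_factor")
mscale = rope_scaling.get("mscale")
mscale_all_dim = rope_scaling.get("mscale_all_dim")

if attention_factor is None:
    if mscale and mscale_all_dim:
        # DeepSeek-style YaRN: ratio of the two mscale corrections
        attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
    else:
        # Plain YaRN default, equivalent to 0.1 * ln(factor) + 1.0
        attention_factor = get_mscale(factor)

print(attention_factor)
```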
optimum/rbln/transformers/models/__init__.py

@@ -88,12 +88,16 @@ _import_structure = {
          "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
          "RBLNQwen2_5_VLForConditionalGeneration",
          "RBLNQwen2_5_VLForConditionalGenerationConfig",
+         "RBLNQwen2_5_VLModel",
+         "RBLNQwen2_5_VLModelConfig",
      ],
      "qwen2_vl": [
          "RBLNQwen2VisionTransformerPretrainedModel",
          "RBLNQwen2VisionTransformerPretrainedModelConfig",
          "RBLNQwen2VLForConditionalGeneration",
          "RBLNQwen2VLForConditionalGenerationConfig",
+         "RBLNQwen2VLModel",
+         "RBLNQwen2VLModelConfig",
      ],
      "decoderonly": [
          "RBLNDecoderOnlyModelConfig",
@@ -110,12 +114,14 @@ _import_structure = {
      ],
      "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
      "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
+     "gemma2": ["RBLNGemma2ForCausalLM", "RBLNGemma2ForCausalLMConfig", "RBLNGemma2Model", "RBLNGemma2ModelConfig"],
      "gemma3": [
          "RBLNGemma3ForCausalLM",
          "RBLNGemma3ForCausalLMConfig",
          "RBLNGemma3ForConditionalGeneration",
          "RBLNGemma3ForConditionalGenerationConfig",
      ],
+     "gpt_oss": ["RBLNGptOssForCausalLM", "RBLNGptOssForCausalLMConfig"],
      "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
      "idefics3": [
          "RBLNIdefics3VisionTransformer",
@@ -132,6 +138,12 @@ _import_structure = {
          "RBLNPegasusForConditionalGenerationConfig",
          "RBLNPegasusModelConfig",
      ],
+     "paligemma": [
+         "RBLNPaliGemmaForConditionalGeneration",
+         "RBLNPaliGemmaForConditionalGenerationConfig",
+         "RBLNPaliGemmaModel",
+         "RBLNPaliGemmaModelConfig",
+     ],
      "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
      "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
      "pixtral": ["RBLNPixtralVisionModel", "RBLNPixtralVisionModelConfig"],
@@ -143,7 +155,9 @@ _import_structure = {
      ],
      "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
      "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
+     "qwen2_moe": ["RBLNQwen2MoeForCausalLM", "RBLNQwen2MoeForCausalLMConfig"],
      "qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
+     "qwen3_moe": ["RBLNQwen3MoeForCausalLM", "RBLNQwen3MoeForCausalLMConfig"],
      "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
      "roberta": [
          "RBLNRobertaForMaskedLM",
@@ -254,6 +268,7 @@ if TYPE_CHECKING:
      from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
      from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
      from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
+     from .gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig, RBLNGemma2Model, RBLNGemma2ModelConfig
      from .gemma3 import (
          RBLNGemma3ForCausalLM,
          RBLNGemma3ForCausalLMConfig,
@@ -261,6 +276,7 @@ if TYPE_CHECKING:
          RBLNGemma3ForConditionalGenerationConfig,
      )
      from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+     from .gpt_oss import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
      from .grounding_dino import (
          RBLNGroundingDinoDecoder,
          RBLNGroundingDinoDecoderConfig,
@@ -281,6 +297,12 @@ if TYPE_CHECKING:
      from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
      from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
      from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
+     from .paligemma import (
+         RBLNPaliGemmaForConditionalGeneration,
+         RBLNPaliGemmaForConditionalGenerationConfig,
+         RBLNPaliGemmaModel,
+         RBLNPaliGemmaModelConfig,
+     )
      from .pegasus import (
          RBLNPegasusForConditionalGeneration,
          RBLNPegasusForConditionalGenerationConfig,
@@ -295,14 +317,20 @@ if TYPE_CHECKING:
          RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
          RBLNQwen2_5_VLForConditionalGeneration,
          RBLNQwen2_5_VLForConditionalGenerationConfig,
+         RBLNQwen2_5_VLModel,
+         RBLNQwen2_5_VLModelConfig,
      )
+     from .qwen2_moe import RBLNQwen2MoeForCausalLM, RBLNQwen2MoeForCausalLMConfig
      from .qwen2_vl import (
          RBLNQwen2VisionTransformerPretrainedModel,
          RBLNQwen2VisionTransformerPretrainedModelConfig,
          RBLNQwen2VLForConditionalGeneration,
          RBLNQwen2VLForConditionalGenerationConfig,
+         RBLNQwen2VLModel,
+         RBLNQwen2VLModelConfig,
      )
      from .qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig, RBLNQwen3Model, RBLNQwen3ModelConfig
+     from .qwen3_moe import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig
      from .resnet import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig
      from .roberta import (
          RBLNRobertaForMaskedLM,
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -60,10 +60,10 @@ class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
  class BartDecoder(Seq2SeqDecoder):
      has_pos_emb = True

-     def __post_init__(self):
-         self.embed_positions = self._original_mod.embed_positions
-         self.layernorm_embedding = self._original_mod.layernorm_embedding
-         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+     def __post_init__(self, model: nn.Module):
+         self.embed_positions = model.embed_positions
+         self.layernorm_embedding = model.layernorm_embedding
+         self.embed_scale = getattr(model, "embed_scale", None)

      def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
          if attention_mask is not None:
@@ -112,11 +112,11 @@ class BartLayerFF(nn.Module):


  class BartDecoderLayer(Seq2SeqDecoderLayer):
-     def __post_init__(self):
-         self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
-         self.encoder_attn = self._original_mod.encoder_attn
-         self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
-         self.ff_layer = BartLayerFF(self._original_mod)
+     def __post_init__(self, decoder_layer: nn.Module):
+         self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+         self.encoder_attn = decoder_layer.encoder_attn
+         self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+         self.ff_layer = BartLayerFF(decoder_layer)

      def pre_self_attn_layer_norm(self, hidden_states):
          return hidden_states
@@ -132,13 +132,13 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):


  class BartSelfAttention(Seq2SeqSelfAttention):
-     def __post_init__(self, use_attention_mask: bool = True):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.out_proj = self._original_mod.out_proj
-         self.num_heads = self._original_mod.num_heads
-         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+     def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
+         self.q_proj = attn.q_proj
+         self.k_proj = attn.k_proj
+         self.v_proj = attn.v_proj
+         self.out_proj = attn.out_proj
+         self.num_heads = attn.num_heads
+         self.head_dim = attn.embed_dim // attn.num_heads
          self.scaling = self.head_dim**-0.5
          if use_attention_mask:
              self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
@@ -153,11 +153,11 @@ class BartSelfAttention(Seq2SeqSelfAttention):


  class BartCrossAttention(Seq2SeqCrossAttention):
-     def __post_init__(self):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.out_proj = self._original_mod.out_proj
-         self.num_heads = self._original_mod.num_heads
-         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
-         self.embed_dim = self._original_mod.embed_dim
+     def __post_init__(self, attn: nn.Module):
+         self.q_proj = attn.q_proj
+         self.k_proj = attn.k_proj
+         self.v_proj = attn.v_proj
+         self.out_proj = attn.out_proj
+         self.num_heads = attn.num_heads
+         self.head_dim = attn.embed_dim // attn.num_heads
+         self.embed_dim = attn.embed_dim
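The common thread in this file is that `__post_init__` now receives the wrapped Hugging Face module as an argument instead of reading a stored `self._original_mod`, so each wrapper copies out only the submodules it needs. An illustrative sketch of the pattern; `WrapperBase` and `DummyAttention` are hypothetical stand-ins, not the `Seq2Seq*` base classes from optimum-rbln:

```python
import torch.nn as nn

# Sketch of the refactor pattern: the base class forwards the original module
# to __post_init__, and the wrapper never stores the module itself.
class WrapperBase(nn.Module):
    def __init__(self, original: nn.Module, *args, **kwargs):
        super().__init__()
        self.__post_init__(original, *args, **kwargs)

    def __post_init__(self, original: nn.Module, *args, **kwargs):
        raise NotImplementedError


class SelfAttentionWrapper(WrapperBase):
    def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
        # Keep references to the projections; the original module is not retained.
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        self.head_dim = attn.embed_dim // attn.num_heads


class DummyAttention(nn.Module):
    def __init__(self, embed_dim: int = 8, num_heads: int = 2):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.embed_dim = embed_dim


wrapper = SelfAttentionWrapper(DummyAttention())
print(wrapper.head_dim)  # 4
```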
optimum/rbln/transformers/models/colpali/colpali_architecture.py

@@ -77,11 +77,11 @@ class ColPaliModel(nn.Module):
          self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
      ):
          super().__init__()
-         self._original_mod = model
          self.layers = nn.ModuleList(layers)
          self.output_hidden_states = output_hidden_states
-         self.norm = self._original_mod.norm
-         self.hidden_size = self._original_mod.config.hidden_size
+         self.config = model.config
+         self.norm = model.norm
+         self.hidden_size = self.config.hidden_size
          self.max_seq_len = max_seq_len

      def forward(
@@ -118,7 +118,6 @@ class ColPaliModel(nn.Module):
  class ColPaliLayer(nn.Module):
      def __init__(self, layer, self_attn: "ColPaliAttention"):
          super().__init__()
-         self._original_mod = layer
          self.self_attn = self_attn
          self.mlp = layer.mlp
          self.input_layernorm = layer.input_layernorm
@@ -155,27 +154,22 @@ class ColPaliLayer(nn.Module):
  class ColPaliAttention(nn.Module):
      def __init__(self, self_attn):
          super().__init__()
-         self._original_mod = self_attn
-         self.num_heads = (
-             getattr(self._original_mod, "num_heads", None) or self._original_mod.config.num_attention_heads
-         )
-         self.head_dim = self._original_mod.head_dim
+         self.config = self_attn.config
+         self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+         self.head_dim = self_attn.head_dim
          self.scaling = self.head_dim**-0.5

-         if hasattr(self._original_mod, "num_key_value_heads"):
-             self.num_key_value_heads = self._original_mod.num_key_value_heads
-         elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
-             self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+         if hasattr(self_attn, "num_key_value_heads"):
+             self.num_key_value_heads = self_attn.num_key_value_heads
+         elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+             self.num_key_value_heads = self_attn.config.num_key_value_heads
          else:
              self.num_key_value_heads = self.num_heads

-         self.__post_init__()
-
-     def __post_init__(self):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.o_proj = self._original_mod.o_proj
+         self.q_proj = self_attn.q_proj
+         self.k_proj = self_attn.k_proj
+         self.v_proj = self_attn.v_proj
+         self.o_proj = self_attn.o_proj

      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_states = self.q_proj(hidden_states)
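One plausible motivation for dropping `self._original_mod` here (not stated in the diff itself) is module registration: assigning an `nn.Module` attribute registers it as a submodule, so the wrapper would otherwise carry every parameter of the original model. A small sketch of that PyTorch behavior; the two classes below are illustrative, not from the package:

```python
import torch.nn as nn

class Keeps(nn.Module):
    def __init__(self, original: nn.Module):
        super().__init__()
        self._original_mod = original          # registers the whole original module


class CopiesOnlyWhatItNeeds(nn.Module):
    def __init__(self, original: nn.Module):
        super().__init__()
        self.config = getattr(original, "config", None)  # plain attribute, no parameters
        self.norm = original.norm                        # only the submodule actually used


orig = nn.Sequential(nn.Linear(4, 4))
orig.norm = nn.LayerNorm(4)

print(sum(p.numel() for p in Keeps(orig).parameters()))                 # includes the Linear weights
print(sum(p.numel() for p in CopiesOnlyWhatItNeeds(orig).parameters()))  # only the LayerNorm
```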
optimum/rbln/transformers/models/colpali/configuration_colpali.py

@@ -11,7 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Any, List, Optional, Union
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
@@ -33,7 +33,9 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):

      # Create a configuration object
      config = RBLNColPaliForRetrievalConfig(
-         max_seq_lens=1152,
+         vlm={
+             "language_model": {"prefill_chunk_size": 8192},
+         }
          output_hidden_states=False,
          tensor_parallel_size=4
      )
@@ -47,24 +49,21 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
      ```
      """

-     submodules = ["vision_tower"]
+     _allow_no_compile_cfgs = True
+     submodules = ["vlm"]

      def __init__(
          self,
          batch_size: Optional[int] = None,
-         max_seq_lens: Union[int, List[int]] = None,
+         vlm: Optional[RBLNModelConfig] = None,
          output_hidden_states: Optional[bool] = None,
-         vision_tower: Optional[RBLNModelConfig] = None,
          **kwargs: Any,
      ):
          """
          Args:
              batch_size (Optional[int]): The batch size for the model.
-             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
-             max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
-                 This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
-             output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
-             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+             vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
              kwargs: Additional arguments passed to the parent RBLNModelConfig.
          Raises:
              ValueError: If batch_size is not a positive integer.
@@ -74,11 +73,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
          if not isinstance(self.batch_size, int) or self.batch_size < 0:
              raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-         if self.batch_size != 1:
-             logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
-
-         self.vision_tower = self.initialize_submodule_config(
-             submodule_config=vision_tower, batch_size=1, force_kwargs=True
+         self.output_hidden_states = output_hidden_states or False
+         self.vlm = self.initialize_submodule_config(
+             submodule_config=vlm, batch_size=batch_size, output_hidden_states=output_hidden_states
          )
-         self.max_seq_lens = max_seq_lens
-         self.output_hidden_states = output_hidden_states
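For reference, the reworked config replaces `max_seq_lens` and `vision_tower` with a single `vlm` submodule config, mirroring the updated docstring example. A sketch with illustrative values; the import path is the module shown in this diff, and `tensor_parallel_size=4` is just an example number:

```python
# Sketch of the 0.9.5-style ColPali retrieval config; values are illustrative.
from optimum.rbln.transformers.models.colpali.configuration_colpali import (
    RBLNColPaliForRetrievalConfig,
)

config = RBLNColPaliForRetrievalConfig(
    vlm={
        "language_model": {"prefill_chunk_size": 8192},  # per-submodule overrides, as in the docstring example
    },
    output_hidden_states=False,
    tensor_parallel_size=4,
)
```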