optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.

Potentially problematic release: this version of optimum-rbln might be problematic.
Files changed (64)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +48 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +35 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  27. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  28. optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
  29. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  30. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  31. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  32. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  33. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  34. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  35. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  36. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  37. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  38. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  39. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  40. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  41. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  42. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  43. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  44. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  45. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  46. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  47. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
  49. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  51. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  52. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  53. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  54. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  55. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  56. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  57. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  58. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  59. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  60. optimum/rbln/utils/depreacate_utils.py +16 -0
  61. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
  62. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
  63. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -0,0 +1,379 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import (
+     AutoModelForImageTextToText,
+     LlavaForConditionalGeneration,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         PretrainedConfig,
+     )
+
+
+ class LoopVisionTower:
+     def __init__(self, vision_tower: RBLNModel) -> None:
+         self.vision_tower = vision_tower
+
+     def forward(self, *args, **kwargs):
+         pixel_values = args[0]
+         image_sizes = kwargs.pop("image_sizes", None)
+
+         outputs = []
+         for i in range(pixel_values.shape[0]):
+             outputs.append(
+                 self.vision_tower(
+                     pixel_values[i : i + 1], image_sizes[i : i + 1] if image_sizes is not None else None, **kwargs
+                 )
+             )
+
+         if hasattr(self.vision_tower.rbln_config, "max_image_size"):
+             last_hidden_states = [output.last_hidden_state for output in outputs]
+             last_hidden_states = torch.cat(last_hidden_states, dim=1)
+             hidden_states = tuple(
+                 torch.cat(
+                     [output.hidden_states[layer_idx] for output in outputs],
+                     dim=1,
+                 )
+                 for layer_idx in range(len(outputs[0].hidden_states))
+             )
+
+         else:
+             last_hidden_states = [output.last_hidden_state for output in outputs]
+             last_hidden_states = torch.cat(last_hidden_states, dim=0)
+             hidden_states = [output.hidden_states for output in outputs]
+             hidden_states = tuple(
+                 torch.cat(tuple((hidden_states[n][i] for n in range(pixel_values.shape[0]))), dim=0)
+                 for i in range(len(hidden_states[0]))
+             )
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_states,
+             hidden_states=hidden_states,
+         )
+
+     def __call__(self, *args: Any, **kwds: Any) -> Any:
+         return self.forward(*args, **kwds)
+
+     def __repr__(self) -> str:
+         return repr(self.vision_tower)
+
+
+ class LoopProjector:
+     def __init__(self, multi_modal_projector) -> None:
+         self.multi_modal_projector = multi_modal_projector
+
+     def forward(self, *args, **kwargs):
+         # Loop instead of batch
+         image_feature = args[0]
+
+         outputs = []
+         for i in range(image_feature.shape[0]):
+             outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
+
+         # FIXME:: This can be optimized using out= API of rbln runtime.
+         outputs = torch.cat(outputs, dim=0)
+         return outputs
+
+     def __call__(self, *args: Any, **kwds: Any) -> Any:
+         return self.forward(*args, **kwds)
+
+     def __repr__(self) -> str:
+         return repr(self.multi_modal_projector)
+
+
+ class RBLNLlavaForConditionalGeneration(RBLNModel):
+     auto_model_class = AutoModelForImageTextToText
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(LlavaForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = self.rbln_submodules[1]
+         self.multi_modal_projector = LoopProjector(self.model[0])
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+         return super().__post_init__(**kwargs)
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+         return model.multi_modal_projector
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         if hasattr(rbln_config.vision_tower, "max_image_size"):
+             num_positions = (
+                 rbln_config.vision_tower.batch_size
+                 * (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size)
+                 * (rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size)
+             )
+             selected_image_feature_dim = num_positions
+
+         else:
+             num_positions = (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2 + 1
+             selected_image_feature_dim = num_positions - 1
+
+         input_info = [
+             (
+                 "image_features",
+                 [rbln_config.batch_size, selected_image_feature_dim, model_config.vision_config.hidden_size],
+                 "float32",
+             )
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         inputs_embeds=None,
+         pixel_values=None,
+         attention_mask=None,
+         cache_position=None,
+         image_sizes=None,
+         generate_idx=None,
+         **kwargs,
+     ):
+         is_prefill_phase = generate_idx is None
+         model_inputs = {}
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
+             pixel_values = pixel_values
+             model_inputs.update({"image_sizes": image_sizes})
+         else:
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+             pixel_values = None
+             input_ids = input_ids[:, -1:]
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.use_inputs_embeds:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "pixel_values": pixel_values,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+             }
+         )
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         return model_kwargs
+
+     def get_image_features(
+         self,
+         pixel_values: torch.FloatTensor,
+         vision_feature_layer: Union[int, List[int]],
+         vision_feature_select_strategy: str,
+         **kwargs,
+     ):
+         if vision_feature_select_strategy not in ["default", "full"]:
+             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+         kwargs = {k: v for k, v in kwargs.items() if v is not None}
+         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+         if isinstance(vision_feature_layer, int):
+             selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+             if vision_feature_select_strategy == "default":
+                 selected_image_feature = selected_image_feature[:, 1:]
+         else:
+             hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+             if vision_feature_select_strategy == "default":
+                 hs_pool = [hs[:, 1:] for hs in hs_pool]
+             selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+         if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+             num_real_patches = selected_image_feature.shape[1]
+             max_patches = (
+                 (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
+                 * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size)
+                 * pixel_values.shape[0]
+             )
+             num_padding_patches = max_patches - num_real_patches
+
+             padding_tensor = torch.zeros(
+                 (selected_image_feature.shape[0], num_padding_patches, selected_image_feature.shape[2]),
+                 dtype=selected_image_feature.dtype,
+             )
+             padded_feature = torch.cat([selected_image_feature, padding_tensor], dim=1)
+             padded_projected_feature = self.multi_modal_projector(padded_feature)
+             image_features = padded_projected_feature[:, :num_real_patches, :]
+         else:
+             image_features = self.multi_modal_projector(selected_image_feature)
+
+         return image_features
+
+     def _preprocess_prefill(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         vision_feature_layer: Optional[Union[int, List[int]]] = None,
+         vision_feature_select_strategy: Optional[str] = None,
+         return_dict: Optional[bool] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         **lm_kwargs,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         vision_feature_layer = (
+             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+         )
+         vision_feature_select_strategy = (
+             vision_feature_select_strategy
+             if vision_feature_select_strategy is not None
+             else self.config.vision_feature_select_strategy
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if pixel_values is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         if pixel_values is not None:
+             image_features = self.get_image_features(
+                 pixel_values=pixel_values,
+                 vision_feature_layer=vision_feature_layer,
+                 vision_feature_select_strategy=vision_feature_select_strategy,
+                 image_sizes=image_sizes,
+             )
+
+             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+             special_image_mask = special_image_mask.expand_as(inputs_embeds)
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+         # Prefill
+         if cache_position is None:
+             inputs_embeds = self._preprocess_prefill(
+                 input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_sizes=image_sizes
+             )
+             logits = []
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 output = self.language_model.prefill_decoder(
+                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                 )
+                 logits.append(output.logits)
+
+             logits = torch.cat(logits, dim=0)
+
+         # Decoder
+         else:
+             logits = self.language_model.decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             ).logits
+
+         if not return_dict:
+             return logits, generate_idx
+         else:
+             return RBLNDecoderOnlyForCausalLMOutput(
+                 logits=logits,
+                 generate_idx=generate_idx,
+             )
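For context, a minimal usage sketch of the RBLNLlavaForConditionalGeneration class added above (not part of the diff). It assumes the usual optimum-rbln `from_pretrained(..., export=True)` compile-on-load flow and that the class is re-exported at the package top level, as the `optimum/rbln/__init__.py` changes in this release suggest; the checkpoint id, prompt, and generation arguments are illustrative.

```python
# Hypothetical usage sketch, not taken from the diff.
import requests
from PIL import Image
from transformers import AutoProcessor
from optimum.rbln import RBLNLlavaForConditionalGeneration  # assumes top-level re-export

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)

# export=True is assumed to compile the vision tower, projector, and language model for RBLN NPUs.
model = RBLNLlavaForConditionalGeneration.from_pretrained(model_id, export=True)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = "USER: <image>\nWhat is shown in this picture? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# Per the forward() above, the prefill pass runs each sample through prefill_decoder,
# and subsequent decoding steps go through the compiled decoder.
output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```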
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -29,7 +29,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput


  logger = get_logger(__name__)
@@ -258,7 +258,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):

      def _update_model_kwargs_for_generation(
          self,
-         outputs: RBLNDecoderOnlyOutput,
+         outputs: RBLNDecoderOnlyForCausalLMOutput,
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
@@ -359,7 +359,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          generate_idx: Optional[torch.Tensor] = None,
          batch_idx: Optional[int] = None,
          **kwargs,
-     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+     ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
          vision_feature_layer = (
              vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
          )
@@ -418,7 +418,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
                  cache_position=cache_position,
              )
              logits = output.logits
-         return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
+         return RBLNDecoderOnlyForCausalLMOutput(logits=logits, generate_idx=generate_idx)

      # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
      def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
optimum/rbln/transformers/models/mistral/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_mistral import RBLNMistralForCausalLMConfig
- from .modeling_mistral import RBLNMistralForCausalLM
+ from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
+ from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
optimum/rbln/transformers/models/mistral/configuration_mistral.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
      )
      ```
      """
+
+
+ class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN Mistral models.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+     """
optimum/rbln/transformers/models/mistral/mistral_architecture.py
@@ -15,5 +15,5 @@
  from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper


- class MistralForCausalLMWrapper(DecoderOnlyWrapper):
+ class MistralWrapper(DecoderOnlyWrapper):
      pass
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -15,8 +15,12 @@
  from transformers import PretrainedConfig

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
- from .mistral_architecture import MistralForCausalLMWrapper
+ from ...models.decoderonly import (
+     RBLNDecoderOnlyModel,
+     RBLNDecoderOnlyModelForCausalLM,
+     RBLNDecoderOnlyModelForCausalLMConfig,
+ )
+ from .mistral_architecture import MistralWrapper


  logger = logging.get_logger(__name__)
@@ -79,7 +83,26 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      ```
      """

-     _decoder_wrapper_cls = MistralForCausalLMWrapper
+     _decoder_wrapper_cls = MistralWrapper
+
+     @classmethod
+     def _update_sliding_window_config(
+         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.cache_impl = "sliding_window"
+         rbln_config.sliding_window = model_config.sliding_window
+         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+         return rbln_config
+
+
+ class RBLNMistralModel(RBLNDecoderOnlyModel):
+     """
+     The Mistral Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = MistralWrapper


      @classmethod
optimum/rbln/transformers/models/opt/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_opt import RBLNOPTForCausalLMConfig
- from .modeling_opt import RBLNOPTForCausalLM
+ from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
+ from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
optimum/rbln/transformers/models/opt/configuration_opt.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -20,3 +20,10 @@ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
      Configuration class for OPT causal language model.
      Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
      """
+
+
+ class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for OPT model.
+     Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
+     """
optimum/rbln/transformers/models/opt/modeling_opt.py
@@ -16,7 +16,7 @@ import torch.nn as nn
  from transformers import PreTrainedModel

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from .opt_architecture import OPTWrapper

@@ -88,3 +88,43 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])

          return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+
+ class RBLNOPTModel(RBLNDecoderOnlyModel):
+     """
+     The OPT Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = OPTWrapper
+     _use_rotary_emb = False
+
+     def modify_opt_decoder_layer(layer):
+         mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+         layer.mlp = mlp
+         del layer.fc1
+         del layer.fc2
+         del layer.activation_fn
+
+         return layer
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         wrapper_cfg = {
+             "max_seq_len": rbln_config.max_seq_len,
+             "attn_impl": rbln_config.attn_impl,
+             "kvcache_partition_len": rbln_config.kvcache_partition_len,
+             "kvcache_block_size": rbln_config.kvcache_block_size,
+             "use_rotary_emb": cls._use_rotary_emb,
+             "use_attention_mask": rbln_config.use_attention_mask,
+             "use_position_ids": rbln_config.use_position_ids,
+             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+             "cache_impl": rbln_config.cache_impl,
+             "sliding_window": rbln_config.sliding_window,
+             "sliding_window_layers": rbln_config.sliding_window_layers,
+         }
+
+         for i in range(len(model.decoder.layers)):
+             model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
+
+         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
optimum/rbln/transformers/models/opt/opt_architecture.py
@@ -40,11 +40,11 @@ class OPTWrapper(DecoderOnlyWrapper):
      def get_rbln_model_class(self):
          return OPTModel

-     def get_model_layer(self, causal_lm: "OPTForCausalLM"):
-         return causal_lm.model.decoder
+     def get_model_layer(self, model: "OPTForCausalLM"):
+         return model.model.decoder if self.is_causal_lm else model.decoder

-     def get_decoder_layers(self, causal_lm: "OPTForCausalLM"):
-         return causal_lm.model.decoder.layers
+     def get_decoder_layers(self, model: "OPTForCausalLM"):
+         return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers


  class OPTAttention(DecoderOnlyAttention):
optimum/rbln/transformers/models/pegasus/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
+ from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
optimum/rbln/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,34 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+ class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+     """
+     Configuration class for RBLNPegasusModel.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for feature extraction tasks.
+     """
+
+
+ class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+     """
+     Configuration class for RBLNPegasusForConditionalGeneration.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for conditional text generation tasks.
+     """