optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; see the registry's advisory page for more details.

Files changed (57)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +40 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +31 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +565 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  27. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  28. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  29. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  30. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  31. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  32. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  33. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  34. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  35. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  36. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  37. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  38. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  39. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  40. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  41. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  42. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  43. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  44. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  45. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  46. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  47. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  48. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  49. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  50. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  51. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  53. optimum/rbln/utils/depreacate_utils.py +16 -0
  54. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
  55. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +57 -51
  56. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
  57. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_gemma import RBLNGemmaForCausalLMConfig
16
- from .modeling_gemma import RBLNGemmaForCausalLM
15
+ from .configuration_gemma import RBLNGemmaForCausalLMConfig, RBLNGemmaModelConfig
16
+ from .modeling_gemma import RBLNGemmaForCausalLM, RBLNGemmaModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
16
 
17
17
 
18
18
  class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
40
40
  )
41
41
  ```
42
42
  """
43
+
44
+
45
+ class RBLNGemmaModelConfig(RBLNDecoderOnlyModelConfig):
46
+ """
47
+ Configuration class for RBLN Gemma models.
48
+
49
+ This class is an alias of RBLNDecoderOnlyModelConfig.
50
+ """
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from ....utils import logging
16
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
16
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
17
17
  from .gemma_architecture import GemmaWrapper
18
18
 
19
19
 
@@ -81,3 +81,15 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
81
81
  """
82
82
 
83
83
  _decoder_wrapper_cls = GemmaWrapper
84
+
85
+
86
+ class RBLNGemmaModel(RBLNDecoderOnlyModel):
87
+ """
88
+ The Gemma Model transformer without a language modeling head.
89
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
90
+
91
+ A class to convert and run pre-trained transformers based GemmaModel model on RBLN devices.
92
+ It implements the methods to convert a pre-trained transformers GemmaModel model into a RBLN transformer model by:
93
+ """
94
+
95
+ _decoder_wrapper_cls = GemmaWrapper
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  from typing import Any, Dict, Optional
15
15
 
16
- import rebel
17
-
18
16
  from ....configuration_utils import RBLNModelConfig
19
17
  from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
20
18
  from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
@@ -39,10 +37,6 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
39
37
  )
40
38
  self.image_prefill_chunk_size = image_prefill_chunk_size
41
39
 
42
- npu = self.npu or rebel.get_npu_name()
43
- if npu == "RBLN-CA02":
44
- raise NotImplementedError("Gemma3 is currently not supported on RBLN-CA02")
45
-
46
40
 
47
41
  class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
48
42
  submodules = ["vision_tower", "language_model"]
@@ -31,7 +31,11 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
31
31
 
32
32
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
33
33
  from ....modeling import RBLNModel
34
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
34
+ from ..decoderonly.modeling_decoderonly import (
35
+ RBLNDecoderOnlyForCausalLMOutput,
36
+ RBLNDecoderOnlyModelForCausalLM,
37
+ RBLNRuntimeModel,
38
+ )
35
39
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
36
40
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
37
41
 
@@ -41,7 +45,7 @@ if TYPE_CHECKING:
41
45
 
42
46
 
43
47
  @dataclass
44
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
48
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
45
49
  attention_mask: Optional[torch.Tensor] = None
46
50
 
47
51
 
@@ -197,7 +201,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
197
201
 
198
202
  def _update_model_kwargs_for_generation(
199
203
  self,
200
- outputs: RBLNDecoderOnlyOutput,
204
+ outputs: RBLNDecoderOnlyForCausalLMOutput,
201
205
  model_kwargs: Dict[str, Any],
202
206
  **kwargs,
203
207
  ) -> Dict[str, Any]:
@@ -266,7 +270,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
266
270
  position_ids: Optional[torch.Tensor] = None,
267
271
  token_type_ids: Optional[torch.Tensor] = None,
268
272
  **lm_kwargs: Dict[str, Any],
269
- ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
273
+ ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
270
274
  # prefill
271
275
  if cache_position is None:
272
276
  logits = []
@@ -304,7 +308,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
304
308
  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
305
309
  ).logits
306
310
 
307
- return RBLNDecoderOnlyOutput(
311
+ return RBLNDecoderOnlyForCausalLMOutput(
308
312
  logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
309
313
  )
310
314
 
@@ -509,7 +513,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
509
513
 
510
514
  logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
511
515
 
512
- return RBLNDecoderOnlyOutput(logits=logits)
516
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
513
517
 
514
518
 
515
519
  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig
16
- from .modeling_gpt2 import RBLNGPT2LMHeadModel
15
+ from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig, RBLNGPT2ModelConfig
16
+ from .modeling_gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2Model
@@ -12,11 +12,39 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
16
 
17
17
 
18
18
  class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
19
19
  """
20
- Configuration class for GPT-2 causal language model.
21
- Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
20
+ Configuration class for RBLN GPT2 models.
21
+
22
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
23
+ """
24
+
25
+
26
+ class RBLNGPT2ModelConfig(RBLNDecoderOnlyModelConfig):
27
+ """
28
+ Configuration class for RBLN GPT2 models.
29
+
30
+ This class is an alias of RBLNDecoderOnlyModelConfig.
31
+
32
+ Example usage:
33
+ ```python
34
+ from optimum.rbln import RBLNGPT2Model, RBLNGPT2ModelConfig
35
+
36
+ # Create a configuration object
37
+ config = RBLNGPT2ModelConfig(
38
+ batch_size=1,
39
+ max_seq_len=1024,
40
+ tensor_parallel_size=4
41
+ )
42
+
43
+ # Use the configuration with from_pretrained
44
+ model = RBLNGPT2Model.from_pretrained(
45
+ "openai/gpt2",
46
+ export=True,
47
+ rbln_config=config
48
+ )
49
+ ```
22
50
  """
@@ -13,11 +13,10 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import math
16
- from typing import TYPE_CHECKING, Tuple
16
+ from typing import TYPE_CHECKING, Tuple, Union
17
17
 
18
18
  import torch
19
19
  import torch.nn as nn
20
- from transformers import PreTrainedModel
21
20
 
22
21
  from ..decoderonly.decoderonly_architecture import (
23
22
  DecoderOnlyAttention,
@@ -28,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (
28
27
 
29
28
 
30
29
  if TYPE_CHECKING:
31
- from transformers import GPT2LMHeadModel
30
+ from transformers import GPT2LMHeadModel, GPT2Model
32
31
 
33
32
 
34
33
  class GPT2Wrapper(DecoderOnlyWrapper):
@@ -44,11 +43,11 @@ class GPT2Wrapper(DecoderOnlyWrapper):
44
43
  def get_attn_layer(self, layer: nn.Module):
45
44
  return layer.attn
46
45
 
47
- def get_model_layer(self, causal_lm: "GPT2LMHeadModel"):
48
- return causal_lm.transformer
46
+ def get_model_layer(self, model: Union["GPT2LMHeadModel", "GPT2Model"]):
47
+ return model.transformer if self.is_causal_lm else model
49
48
 
50
- def get_decoder_layers(self, causal_lm: PreTrainedModel):
51
- return causal_lm.transformer.h
49
+ def get_decoder_layers(self, model: Union["GPT2LMHeadModel", "GPT2Model"]):
50
+ return model.transformer.h if self.is_causal_lm else model.h
52
51
 
53
52
 
54
53
  class GPT2Model(DecoderOnlyModel):
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from ....utils import logging
16
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
16
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
17
17
  from .gpt2_architecture import GPT2Wrapper
18
18
 
19
19
 
@@ -36,3 +36,18 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
36
36
 
37
37
  _decoder_wrapper_cls = GPT2Wrapper
38
38
  _use_rotary_emb = False
39
+
40
+
41
+ class RBLNGPT2Model(RBLNDecoderOnlyModel):
42
+ """
43
+ The GPT2 Model transformer without a language modeling head.
44
+
45
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the
46
+ library implements for all its model.
47
+
48
+ A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
49
+ It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
50
+ """
51
+
52
+ _decoder_wrapper_cls = GPT2Wrapper
53
+ _use_rotary_emb = False
@@ -35,7 +35,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
35
35
  from ....modeling import RBLNModel
36
36
  from ....utils.runtime_utils import RBLNPytorchRuntime
37
37
  from ..decoderonly.modeling_decoderonly import (
38
- RBLNDecoderOnlyOutput,
38
+ RBLNDecoderOnlyForCausalLMOutput,
39
39
  )
40
40
 
41
41
 
@@ -494,7 +494,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
494
494
  if not return_dict:
495
495
  return logits, generate_idx
496
496
  else:
497
- return RBLNDecoderOnlyOutput(
497
+ return RBLNDecoderOnlyForCausalLMOutput(
498
498
  logits=logits,
499
499
  generate_idx=generate_idx,
500
500
  )
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_llama import RBLNLlamaForCausalLMConfig
16
- from .modeling_llama import RBLNLlamaForCausalLM
15
+ from .configuration_llama import RBLNLlamaForCausalLMConfig, RBLNLlamaModelConfig
16
+ from .modeling_llama import RBLNLlamaForCausalLM, RBLNLlamaModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
16
 
17
17
 
18
18
  class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
40
40
  )
41
41
  ```
42
42
  """
43
+
44
+
45
+ class RBLNLlamaModelConfig(RBLNDecoderOnlyModelConfig):
46
+ """
47
+ Configuration class for RBLN Llama models.
48
+
49
+ This class is an alias of RBLNDecoderOnlyModelConfig.
50
+ """
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from ....utils import logging
16
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
16
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
17
17
  from .llama_architecture import LlamaWrapper
18
18
 
19
19
 
@@ -81,3 +81,15 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
81
81
  """
82
82
 
83
83
  _decoder_wrapper_cls = LlamaWrapper
84
+
85
+
86
+ class RBLNLlamaModel(RBLNDecoderOnlyModel):
87
+ """
88
+ The Llama Model transformer with a language modeling head (linear layer) on top.
89
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
90
+
91
+ A class to convert and run pre-trained transformers based LlamaModel model on RBLN devices.
92
+ It implements the methods to convert a pre-trained transformers LlamaModel model into a RBLN transformer model by:
93
+ """
94
+
95
+ _decoder_wrapper_cls = LlamaWrapper
@@ -29,7 +29,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
29
29
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
30
30
  from ....modeling import RBLNModel
31
31
  from ....utils.logging import get_logger
32
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
32
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput
33
33
 
34
34
 
35
35
  logger = get_logger(__name__)
@@ -258,7 +258,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
258
258
 
259
259
  def _update_model_kwargs_for_generation(
260
260
  self,
261
- outputs: RBLNDecoderOnlyOutput,
261
+ outputs: RBLNDecoderOnlyForCausalLMOutput,
262
262
  model_kwargs: Dict[str, Any],
263
263
  **kwargs,
264
264
  ) -> Dict[str, Any]:
@@ -359,7 +359,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
359
359
  generate_idx: Optional[torch.Tensor] = None,
360
360
  batch_idx: Optional[int] = None,
361
361
  **kwargs,
362
- ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
362
+ ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
363
363
  vision_feature_layer = (
364
364
  vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
365
365
  )
@@ -418,7 +418,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
418
418
  cache_position=cache_position,
419
419
  )
420
420
  logits = output.logits
421
- return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
421
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits, generate_idx=generate_idx)
422
422
 
423
423
  # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
424
424
  def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_mistral import RBLNMistralForCausalLMConfig
16
- from .modeling_mistral import RBLNMistralForCausalLM
15
+ from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
16
+ from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
16
 
17
17
 
18
18
  class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
40
40
  )
41
41
  ```
42
42
  """
43
+
44
+
45
+ class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
46
+ """
47
+ Configuration class for RBLN Mistral models.
48
+
49
+ This class is an alias of RBLNDecoderOnlyModelConfig.
50
+ """
@@ -15,5 +15,5 @@
15
15
  from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
16
16
 
17
17
 
18
- class MistralForCausalLMWrapper(DecoderOnlyWrapper):
18
+ class MistralWrapper(DecoderOnlyWrapper):
19
19
  pass
@@ -15,8 +15,12 @@
15
15
  from transformers import PretrainedConfig
16
16
 
17
17
  from ....utils import logging
18
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
19
- from .mistral_architecture import MistralForCausalLMWrapper
18
+ from ...models.decoderonly import (
19
+ RBLNDecoderOnlyModel,
20
+ RBLNDecoderOnlyModelForCausalLM,
21
+ RBLNDecoderOnlyModelForCausalLMConfig,
22
+ )
23
+ from .mistral_architecture import MistralWrapper
20
24
 
21
25
 
22
26
  logger = logging.get_logger(__name__)
@@ -79,7 +83,26 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
79
83
  ```
80
84
  """
81
85
 
82
- _decoder_wrapper_cls = MistralForCausalLMWrapper
86
+ _decoder_wrapper_cls = MistralWrapper
87
+
88
+ @classmethod
89
+ def _update_sliding_window_config(
90
+ cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
91
+ ):
92
+ rbln_config.cache_impl = "sliding_window"
93
+ rbln_config.sliding_window = model_config.sliding_window
94
+ rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
95
+
96
+ return rbln_config
97
+
98
+
99
+ class RBLNMistralModel(RBLNDecoderOnlyModel):
100
+ """
101
+ The Mistral Model transformer without a language modeling head.
102
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
103
+ """
104
+
105
+ _decoder_wrapper_cls = MistralWrapper
83
106
 
84
107
  @classmethod
85
108
  def _update_sliding_window_config(
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_opt import RBLNOPTForCausalLMConfig
16
- from .modeling_opt import RBLNOPTForCausalLM
15
+ from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
16
+ from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
16
 
17
17
 
18
18
  class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -20,3 +20,10 @@ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
20
20
  Configuration class for OPT causal language model.
21
21
  Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
22
22
  """
23
+
24
+
25
+ class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
26
+ """
27
+ Configuration class for OPT model.
28
+ Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
29
+ """
@@ -16,7 +16,7 @@ import torch.nn as nn
16
16
  from transformers import PreTrainedModel
17
17
 
18
18
  from ....utils import logging
19
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
19
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
20
20
  from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
21
21
  from .opt_architecture import OPTWrapper
22
22
 
@@ -88,3 +88,43 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
88
88
  model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
89
89
 
90
90
  return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
91
+
92
+
93
+ class RBLNOPTModel(RBLNDecoderOnlyModel):
94
+ """
95
+ The OPT Model transformer without a language modeling head.
96
+ This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
97
+ """
98
+
99
+ _decoder_wrapper_cls = OPTWrapper
100
+ _use_rotary_emb = False
101
+
102
+ def modify_opt_decoder_layer(layer):
103
+ mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
104
+ layer.mlp = mlp
105
+ del layer.fc1
106
+ del layer.fc2
107
+ del layer.activation_fn
108
+
109
+ return layer
110
+
111
+ @classmethod
112
+ def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
113
+ wrapper_cfg = {
114
+ "max_seq_len": rbln_config.max_seq_len,
115
+ "attn_impl": rbln_config.attn_impl,
116
+ "kvcache_partition_len": rbln_config.kvcache_partition_len,
117
+ "kvcache_block_size": rbln_config.kvcache_block_size,
118
+ "use_rotary_emb": cls._use_rotary_emb,
119
+ "use_attention_mask": rbln_config.use_attention_mask,
120
+ "use_position_ids": rbln_config.use_position_ids,
121
+ "use_inputs_embeds": rbln_config.use_inputs_embeds,
122
+ "cache_impl": rbln_config.cache_impl,
123
+ "sliding_window": rbln_config.sliding_window,
124
+ "sliding_window_layers": rbln_config.sliding_window_layers,
125
+ }
126
+
127
+ for i in range(len(model.decoder.layers)):
128
+ model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
129
+
130
+ return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
@@ -40,11 +40,11 @@ class OPTWrapper(DecoderOnlyWrapper):
40
40
  def get_rbln_model_class(self):
41
41
  return OPTModel
42
42
 
43
- def get_model_layer(self, causal_lm: "OPTForCausalLM"):
44
- return causal_lm.model.decoder
43
+ def get_model_layer(self, model: "OPTForCausalLM"):
44
+ return model.model.decoder if self.is_causal_lm else model.decoder
45
45
 
46
- def get_decoder_layers(self, causal_lm: "OPTForCausalLM"):
47
- return causal_lm.model.decoder.layers
46
+ def get_decoder_layers(self, model: "OPTForCausalLM"):
47
+ return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers
48
48
 
49
49
 
50
50
  class OPTAttention(DecoderOnlyAttention):
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
16
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
17
+ from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
@@ -0,0 +1,34 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
16
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
17
+
18
+
19
+ class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
20
+ """
21
+ Configuration class for RBLNPegasusModel.
22
+
23
+ This configuration class stores the configuration parameters specific to
24
+ RBLN-optimized PEGASUS models for feature extraction tasks.
25
+ """
26
+
27
+
28
+ class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
29
+ """
30
+ Configuration class for RBLNPegasusForConditionalGeneration.
31
+
32
+ This configuration class stores the configuration parameters specific to
33
+ RBLN-optimized PEGASUS models for conditional text generation tasks.
34
+ """