optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +48 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +35 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  27. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  28. optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
  29. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  30. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  31. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  32. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  33. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  34. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  35. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  36. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  37. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  38. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  39. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  40. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  41. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  42. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  43. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  44. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  45. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  46. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  47. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
  49. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  51. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  52. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  53. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  54. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  55. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  56. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  57. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  58. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  59. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  60. optimum/rbln/utils/depreacate_utils.py +16 -0
  61. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
  62. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
  63. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,69 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from transformers import PegasusForConditionalGeneration, PreTrainedModel
+
+ from ....utils.logging import get_logger
+ from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
+ from .pegasus_architecture import PegasusWrapper
+
+
+ logger = get_logger()
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel
+
+
+ class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
+     """
+     RBLN optimized PEGASUS model for feature extraction tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS encoder models
+     on RBLN devices, optimized for feature extraction use cases.
+     """
+
+
+ class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     """
+     RBLN optimized PEGASUS model for conditional text generation tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS models
+     on RBLN devices, supporting sequence-to-sequence generation tasks
+     such as summarization, translation, and text generation.
+     """
+
+     support_causal_attn = True
+
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
+         return PegasusWrapper(
+             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(PegasusForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+         return val
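The new PEGASUS classes plug into the usual optimum-rbln workflow. Below is a minimal usage sketch, assuming RBLNPegasusForConditionalGeneration follows the same from_pretrained/export flow as the existing seq2seq models in this package; the checkpoint name, the top-level re-export, and the generation arguments are illustrative assumptions rather than documented behavior.

from transformers import AutoTokenizer
from optimum.rbln import RBLNPegasusForConditionalGeneration  # assumes the top-level re-export added in this release

model_id = "google/pegasus-xsum"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True compiles the Hugging Face checkpoint for RBLN NPUs on first load.
model = RBLNPegasusForConditionalGeneration.from_pretrained(model_id, export=True)

article = "PEGASUS was pre-trained with a gap-sentence generation objective for abstractive summarization."
inputs = tokenizer(article, return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))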
optimum/rbln/transformers/models/pegasus/pegasus_architecture.py
@@ -0,0 +1,163 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from transformers.modeling_attn_mask_utils import (
+     _prepare_4d_attention_mask,
+ )
+ from transformers.utils import logging
+
+ from ..seq2seq.seq2seq_architecture import (
+     Seq2SeqCrossAttention,
+     Seq2SeqDecoder,
+     Seq2SeqDecoderLayer,
+     Seq2SeqDecoderWrapper,
+     Seq2SeqEncoderWrapper,
+     Seq2SeqForConditionalGeneration,
+     Seq2SeqSelfAttention,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PegasusWrapper:
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
+
+
+ class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
+             cross_attn = PegasusCrossAttention(layer.encoder_attn)
+             new_layers.append(PegasusDecoderLayer(layer, self_attn, cross_attn))
+
+         decoder_model = PegasusDecoder(model.get_decoder(), new_layers)
+         new_model = PegasusForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+
+ class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
+     pass
+
+
+ class PegasusDecoder(Seq2SeqDecoder):
+     has_pos_emb = True
+
+     def __post_init__(self):
+         self.embed_positions = self._original_mod.embed_positions
+         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+         self.final_layer_norm = getattr(self._original_mod, "layer_norm", None)
+
+     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+         encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+         return attention_mask, encoder_attention_mask
+
+     def apply_position_embedding(self, inputs_embeds, cache_position):
+         hidden_all = []
+         for i in range(inputs_embeds.shape[0]):
+             positions_idx = cache_position[i]
+             position_weight = self.embed_positions.weight
+             position = position_weight[positions_idx]
+             batch_hidden = position + inputs_embeds[i]
+             hidden_all.append(batch_hidden)
+         hidden_states = torch.stack(hidden_all, dim=0)
+
+         return hidden_states
+
+     def get_embedding(self):
+         if self.embed_scale is not None:
+             return lambda x: self.embed_tokens(x) * self.embed_scale
+         else:
+             return self.embed_tokens
+
+
+ class PegasusLayerFF(nn.Module):
+     def __init__(self, decoder_layer):
+         super().__init__()
+         self.fc1 = decoder_layer.fc1
+         self.fc2 = decoder_layer.fc2
+         self.activation_fn = decoder_layer.activation_fn
+         self.layer_norm = decoder_layer.final_layer_norm
+
+     def forward(self, hidden_states):
+         # Residual Connection
+         residual = hidden_states
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class PegasusDecoderLayer(Seq2SeqDecoderLayer):
+     def __post_init__(self):
+         self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+         self.encoder_attn = self._original_mod.encoder_attn
+         self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+         self.ff_layer = PegasusLayerFF(self._original_mod)
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         return self.self_attn_layer_norm(hidden_states)
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         return self.encoder_attn_layer_norm(hidden_states)
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+
+ class PegasusSelfAttention(Seq2SeqSelfAttention):
+     def __post_init__(self, use_attention_mask: bool = True):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.scaling = self.head_dim**-0.5
+         if use_attention_mask:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+         else:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         query_states = self.q_proj(hidden_states) * self.scaling
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+
+ class PegasusCrossAttention(Seq2SeqCrossAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.embed_dim = self._original_mod.embed_dim
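PegasusWrapper follows the shared seq2seq wrapper design: it splits the Hugging Face model into a Seq2SeqEncoderWrapper and a PegasusDecoderWrapper whose attention layers are rebuilt around the RBLN paged-attention custom ops. A rough sketch of that split follows, assuming a stock transformers checkpoint; the enc_max_seq_len and use_attention_mask values are illustrative, not defaults.

from transformers import PegasusForConditionalGeneration
from optimum.rbln.transformers.models.pegasus.pegasus_architecture import PegasusWrapper

hf_model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")

# The wrapper produces the two graphs that the runtime handles separately.
wrapper = PegasusWrapper(hf_model, enc_max_seq_len=512, use_attention_mask=True)
print(type(wrapper.encoder).__name__)  # Seq2SeqEncoderWrapper
print(type(wrapper.decoder).__name__)  # PegasusDecoderWrapper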
optimum/rbln/transformers/models/phi/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from .configuration_phi import RBLNPhiForCausalLMConfig
- from .modeling_phi import RBLNPhiForCausalLM
+ from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
+ from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
optimum/rbln/transformers/models/phi/configuration_phi.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
  class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
      )
      ```
      """
+
+
+ class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN Phi models.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+     """
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -13,7 +13,7 @@
  # limitations under the License.
 
  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .phi_architecture import PhiWrapper
 
 
@@ -81,3 +81,12 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
 
      _decoder_wrapper_cls = PhiWrapper
+
+
+ class RBLNPhiModel(RBLNDecoderOnlyModel):
+     """
+     The Phi Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = PhiWrapper
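RBLNPhiModel exposes the headless Phi transformer for hidden-state extraction, alongside the RBLNPhiModelConfig companion added above. The minimal sketch below assumes the class shares the from_pretrained/export flow of the other RBLNDecoderOnlyModel subclasses in this release; the checkpoint, the top-level re-export, and the exact shape of the returned output are assumptions.

import torch
from transformers import AutoTokenizer
from optimum.rbln import RBLNPhiModel  # assumes the top-level re-export added in this release

model_id = "microsoft/phi-2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNPhiModel.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# The headless model yields hidden states rather than LM logits (assumed output attribute).
print(outputs.last_hidden_state.shape)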
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from typing import TYPE_CHECKING, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
 
  import torch
  from transformers import PhiForCausalLM
@@ -27,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (
 
 
  if TYPE_CHECKING:
-     from transformers import PhiForCausalLM
+     from transformers import PhiForCausalLM, PhiModel
 
 
  class PhiWrapper(DecoderOnlyWrapper):
@@ -40,11 +40,11 @@ class PhiWrapper(DecoderOnlyWrapper):
      def get_rbln_model_class(self):
          return PhiModel
 
-     def get_model_layer(self, causal_lm: "PhiForCausalLM"):
-         return causal_lm.model
+     def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model if self.is_causal_lm else model
 
-     def get_decoder_layers(self, causal_lm: "PhiForCausalLM"):
-         return causal_lm.model.layers
+     def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model.layers if self.is_causal_lm else model.layers
 
 
  class PhiAttention(DecoderOnlyAttention):
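The get_model_layer/get_decoder_layers hooks now branch on is_causal_lm, so a single PhiWrapper can serve both PhiForCausalLM and the new bare PhiModel. Below is a toy illustration of the two resolution paths using small randomly initialized transformers models; is_causal_lm itself is assumed to be set by DecoderOnlyWrapper based on the wrapped class.

from transformers import PhiConfig, PhiForCausalLM, PhiModel

cfg = PhiConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                num_attention_heads=4, vocab_size=128)
causal_lm = PhiForCausalLM(cfg)   # transformer nested under .model -> the is_causal_lm=True path
base_model = PhiModel(cfg)        # transformer is the module itself -> the is_causal_lm=False path

# What the wrapper hooks resolve to in each case:
print(len(causal_lm.model.layers), len(base_model.layers))  # 2 2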
optimum/rbln/transformers/models/pixtral/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_pixtral import RBLNPixtralVisionModelConfig
+ from .modeling_pixtral import RBLNPixtralVisionModel
optimum/rbln/transformers/models/pixtral/configuration_pixtral.py
@@ -0,0 +1,43 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNPixtralVisionModelConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         max_image_size: Tuple = None,
+         batch_size: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Dict[str, Any],
+     ):
+         """
+         Args:
+             max_image_size (Tuple): The size of max input images. A tuple (max_height, max_width)
+             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.max_image_size = max_image_size
+         self.output_hidden_states = output_hidden_states
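A construction sketch for the new config, based only on the signature above; the size and batch values are illustrative, and passing the config through from_pretrained(rbln_config=...) is assumed to match how other RBLNModelConfig subclasses are consumed.

from optimum.rbln.transformers.models.pixtral.configuration_pixtral import RBLNPixtralVisionModelConfig

config = RBLNPixtralVisionModelConfig(
    max_image_size=(1024, 1024),  # (max_height, max_width) the vision tower is compiled for
    batch_size=1,
    output_hidden_states=False,
)
print(config.max_image_size, config.batch_size)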