optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff shows the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/gemma2/configuration_gemma2.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
+    # Create a configuration object
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """

optimum/rbln/transformers/models/gemma2/gemma2_architecture.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class Gemma2Wrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Gemma2DecoderLayer
+
+    def get_rbln_attn_class(self):
+        return Gemma2Attention
+
+    def get_rbln_model_class(self):
+        return Gemma2Model
+
+
+class Gemma2DecoderLayer(DecoderOnlyLayer):
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
+        )
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+        hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma2Attention(DecoderOnlyAttention):
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
+
+
+class Gemma2Model(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self.config.hidden_size**0.5
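
Editor's note: the only Gemma2-specific numerics above are the attention scale (`query_pre_attn_scalar**-0.5` instead of the usual `head_dim**-0.5`) and the `sqrt(hidden_size)` embedding multiplier. A minimal sketch of what these two overrides compute; the config values are illustrative placeholders, not read from an actual Gemma2 checkpoint:

```python
# Illustrative values only; real Gemma2 checkpoints define these in their config.
query_pre_attn_scalar = 256
hidden_size = 3584

# What Gemma2Attention.get_attn_scale returns: attention logits are scaled by
# 1/sqrt(query_pre_attn_scalar) rather than 1/sqrt(head_dim).
attn_scale = query_pre_attn_scalar**-0.5
print(attn_scale)  # 0.0625

# What Gemma2Model.hidden_multiplier returns: token embeddings are multiplied
# by sqrt(hidden_size) before entering the decoder stack.
hidden_multiplier = hidden_size**0.5
print(round(hidden_multiplier, 3))  # 59.867
```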

optimum/rbln/transformers/models/gemma2/modeling_gemma2.py
@@ -0,0 +1,101 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+)
+from .gemma2_architecture import Gemma2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNGemma2ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGemma2ForCausalLMConfig
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
+
+
+class RBLNGemma2Model(RBLNDecoderOnlyModel):
+    """
+    The Gemma2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2Model model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2Model model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
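
Editor's note: the docstring above covers only the causal-LM class. A usage sketch for the headless `RBLNGemma2Model` would look analogous; the arguments below mirror the causal-LM example and are an assumption, not taken from the diff:

```python
from optimum.rbln import RBLNGemma2Model

# Compile the Gemma2 backbone without the LM head; it returns hidden states
# rather than logits. Same from_pretrained/export workflow as the causal LM.
model = RBLNGemma2Model.from_pretrained(
    "google/gemma-2-9b",
    export=True,
    rbln_batch_size=1,
    rbln_tensor_parallel_size=4,
)
model.save_pretrained("gemma-2-9b-rbln")  # hypothetical output directory
```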

optimum/rbln/transformers/models/gemma3/configuration_gemma3.py
@@ -58,13 +58,8 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         )
         self.image_prefill_chunk_size = image_prefill_chunk_size
 
-    @property
-    def use_image_prefill(self):
-        return self.image_prefill_chunk_size is not None
-
-    @property
-    def decoder_runtime_idx(self):
-        return 2 if self.use_image_prefill else 1
+        if not (self.use_attention_mask and self.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
 
 
 class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
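
Editor's note: this hunk drops the `use_image_prefill`/`decoder_runtime_idx` properties (the modeling hunks below derive runtimes from `phases` instead) and moves the attention-mask/position-ids check into the constructor. A sketch of the resulting fail-fast behavior, assuming these keyword arguments are accepted and not overridden by the base class:

```python
from optimum.rbln import RBLNGemma3ForCausalLMConfig

# The check now runs in __init__, so an unsupported combination raises
# immediately instead of surfacing later in _update_rbln_config.
try:
    RBLNGemma3ForCausalLMConfig(use_attention_mask=False, use_position_ids=True)
except ValueError as err:
    print(err)  # use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM
```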

optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -16,7 +16,6 @@ import copy
 from typing import Optional, Tuple, Union
 
 import torch
-from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
@@ -95,16 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
         else:
             seq_positions = cache_position[:, :1]
 
-        sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+        cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+        sliding_cache_pos = (cache_seq_len, cache_offset)
 
         all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos_local if is_sliding else cos_global,
@@ -120,11 +121,8 @@
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
-    def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.pre_feedforward_layernorm
-
-    def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.post_feedforward_layernorm
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
 
     def forward(
         self,
@@ -164,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
 
 
 class Gemma3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
-
-    def get_attn_scale(self):
-        return self._original_mod.config.query_pre_attn_scalar**-0.5
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
+
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
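
Editor's note: the `Gemma3TextModel` hunk above routes decode-phase sliding-window layers to a dedicated `swa_attn_mask`. Conceptually, such a mask restricts the new token to the last `sliding_window` cache slots; a toy illustration in plain PyTorch (not the RBLN custom op, whose mask layout is not shown in this diff):

```python
import torch

# Toy decode-phase sliding-window mask: the token at `current_pos` may attend
# only to the `sliding_window` most recent positions, inclusive of itself.
sliding_window = 4
cache_len = 10
current_pos = 9

pos = torch.arange(cache_len)
swa_mask = ((pos > current_pos - sliding_window) & (pos <= current_pos)).float()
print(swa_mask)  # tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1.])
```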

optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -13,11 +13,9 @@
 # limitations under the License.
 import importlib
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
-import rebel
 import torch
-from rebel.compile_context import CompileContext
 from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
@@ -29,10 +27,7 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from ...utils.rbln_runtime_wrapper import LoopProcessor
 from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
 from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
-from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyModelForCausalLM,
-)
-from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
 from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
 
@@ -325,7 +320,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
                 batch_size,
                 inputs_embeds.shape[1],
                 self.config.text_config.hidden_size,
-                dtype=self.rbln_config.torch_dtype,
+                dtype=self.rbln_config.dtype,
             )
             for _ in range(self.config.text_config.num_hidden_layers + 1)
         )
@@ -455,174 +450,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )
 
-        return rbln_config
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
-    ) -> RBLNGemma3ForCausalLMConfig:
-        # Update rbln_config with super class
-        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
-
-        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
-            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
-
-        if rbln_config.use_image_prefill:
-            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
-                raise NotImplementedError(
-                    "Not implemented for different prefill chunk sizes between text and image prefill."
-                )
-
-            # Update image prefill compile config
-            img_prefill_input_info = cls.get_input_info(
-                batch_size=1,
-                query_length=rbln_config.image_prefill_chunk_size,
-                rbln_config=rbln_config,
-                model_config=model_config,
-            )
-            image_prefill_compile_config = RBLNCompileConfig(
-                compiled_model_name="image_prefill", input_info=img_prefill_input_info
-            )
-            # Insert image_prefill compile config at index 1
-            compile_cfgs = rbln_config.compile_cfgs
-            compile_cfgs.insert(1, image_prefill_compile_config)
-            rbln_config.set_compile_cfgs(compile_cfgs)
+        if "image_prefill" not in rbln_config.phases:
+            rbln_config.phases = ["prefill", "image_prefill", "decode"]
 
         return rbln_config
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
-
-        rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_compile_config = rbln_compile_configs[0]
-
-        context = CompileContext(use_weight_sharing=True)
-
-        # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
-        # Mark static tensors (self kv states)
-        static_tensors = {}
-        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
-            if "past_key_values" in name:
-                static_tensors[name] = tensor
-                context.mark_static_address(tensor)
-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
-            try:
-                if quantization:
-                    quantization.maybe_set_quantization_env()
-                original_linear = torch.nn.functional.linear
-                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                compiled_model = cls.compile(
-                    wrapped_model,
-                    compile_config,
-                    create_runtimes=rbln_config.create_runtimes,
-                    device=rbln_config.device,
-                    example_inputs=example_inputs,
-                    compile_context=compile_context,
-                )
-                return compiled_model
-            finally:
-                torch.nn.functional.linear = original_linear
-                if quantization:
-                    quantization.maybe_reset_quantization_env()
-
-        wrapped_model.phase = "prefill"
-        compiled_prefill = compile_model(
-            wrapped_model,
-            prefill_compile_config,
-            prefill_example_inputs,
-            context,
-            rbln_config.quantization,
-        )
-        compiled_models = {"prefill": compiled_prefill}
-
-        if rbln_config.use_image_prefill:
-            image_prefill_compile_config = rbln_compile_configs[1]
-            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
-                fill=0, static_tensors=static_tensors
-            )
-            wrapped_model.phase = "image_prefill"
-            compiled_image_prefill = compile_model(
-                wrapped_model,
-                image_prefill_compile_config,
-                image_prefill_example_inputs,
-                context,
-                rbln_config.quantization,
-            )
-            compiled_models["image_prefill"] = compiled_image_prefill
-
-        wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(
-            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
-        ):
-            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-            compiled_decoder = compile_model(
-                wrapped_model,
-                dec_compile_config,
-                dec_example_inputs,
-                context,
-                rbln_config.quantization,
-            )
-            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-        return compiled_models
-
-    @classmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNGemma3ForCausalLMConfig,
-    ) -> List[rebel.Runtime]:
-        expected_model_names = [
-            "prefill",
-            *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
-        ]
-        if rbln_config.use_image_prefill:
-            expected_model_names.insert(1, "image_prefill")
-
-        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
-            cls._raise_missing_compiled_file_error(expected_model_names)
-
-        ret_val = [
-            rebel.Runtime(
-                compiled_models[0],
-                tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-                timeout=rbln_config.timeout,
-            )
-        ]
-        if rbln_config.use_image_prefill:
-            ret_val.append(
-                rebel.Runtime(
-                    compiled_models[1],
-                    tensor_type="pt",
-                    device=rbln_config.device_map["image_prefill"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                ),
-            )
-
-        ret_val.extend(
-            [
-                rebel.Runtime(
-                    compiled_models[i + rbln_config.decoder_runtime_idx],
-                    tensor_type="pt",
-                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                )
-                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ]
-        )
-
-        return ret_val
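
Editor's note: the large deletion above is the point of this refactor. Gemma3 no longer hand-builds the `image_prefill` compile config and runtime; it declares an extra entry in `phases` and leaves compile-config and runtime creation to the base class. A rough sketch of the naming this implies; the `decoder_batch_{n}` convention comes from the removed code, the rest is assumed:

```python
# Assumed illustration of the phase-driven layout after this change:
# one compiled model/runtime per non-decode phase, plus one decoder
# runtime per configured batch size.
phases = ["prefill", "image_prefill", "decode"]
decoder_batch_sizes = [1, 4]  # hypothetical values

runtime_names = [p for p in phases if p != "decode"]
runtime_names += [f"decoder_batch_{b}" for b in decoder_batch_sizes]
print(runtime_names)
# ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']
```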

optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -20,8 +20,6 @@ import torch.nn as nn
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )
 
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
     def get_rbln_attn_class(self):
         return GPT2Attention
 
-    def get_rbln_layer_class(self):
-        return GPT2Layer
-
-    def get_rbln_model_class(self):
-        return GPT2Model
-
     def get_attn_layer(self, layer: nn.Module):
         return layer.attn
 
@@ -50,30 +42,12 @@
         return model.transformer.h if self.is_causal_lm else model.h
 
 
-class GPT2Model(DecoderOnlyModel):
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_pos_embedding(self) -> nn.Embedding:
-        return self._original_mod.wpe
-
-
-class GPT2Layer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
 class GPT2Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.c_attn = self._original_mod.c_attn
-        self.o_proj = self._original_mod.c_proj
-        self.split_size = self._original_mod.split_size
+    def __post_init__(self, self_attn):
+        self.c_attn = self_attn.c_attn
+        self.o_proj = self_attn.c_proj
+        self.split_size = self_attn.split_size
+        self.num_key_value_heads = self_attn.num_heads
 
     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 
-    def get_attn_scale(self):
+    def get_attn_scale(self, self_attn):
         scale = 1.0
-        if self._original_mod.scale_attn_weights:
+        if self_attn.scale_attn_weights:
             scale /= math.sqrt(self.head_dim)
 
-        if self._original_mod.scale_attn_by_inverse_layer_idx:
+        if self_attn.scale_attn_by_inverse_layer_idx:
             scale /= 1 + self.layer_idx
 
         return scale
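
Editor's note: the refactored `get_attn_scale` keeps GPT-2's two optional scalings. A worked example with GPT-2-small-style values; enabling `scale_attn_by_inverse_layer_idx` is an assumption here, as it defaults to False in stock GPT-2 configs:

```python
import math

# GPT-2 attention scale, as computed by GPT2Attention.get_attn_scale above.
head_dim = 64                           # GPT-2 small: 768 hidden / 12 heads
scale_attn_weights = True               # standard 1/sqrt(head_dim) scaling
scale_attn_by_inverse_layer_idx = True  # assumption: off by default in GPT-2
layer_idx = 3                           # 0-based layer index

scale = 1.0
if scale_attn_weights:
    scale /= math.sqrt(head_dim)        # 1/8 = 0.125
if scale_attn_by_inverse_layer_idx:
    scale /= 1 + layer_idx              # 0.125 / 4 = 0.03125

print(scale)  # 0.03125
```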

optimum/rbln/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
+from .modeling_gpt_oss import RBLNGptOssForCausalLM

optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py
@@ -0,0 +1,42 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN GptOss models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNGptOssForCausalLMConfig(
+        batch_size=1,
+        tensor_parallel_size=8,
+        kvcache_partition_len=8192,
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=config,
+    )
+    ```
+    """