optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -18,9 +18,6 @@ import torch.nn as nn
 
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )
 
@@ -42,36 +39,3 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
 
     def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
         return causal_lm.transformer
-
-    def get_rbln_attn_class(self):
-        return ExaoneAttention
-
-    def get_rbln_layer_class(self):
-        return ExaoneLayer
-
-    def get_rbln_model_class(self):
-        return ExaoneModel
-
-
-class ExaoneModel(DecoderOnlyModel):
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-
-class ExaoneLayer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
-class ExaoneAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -24,4 +24,4 @@ class GemmaWrapper(DecoderOnlyWrapper):
 class GemmaModel(DecoderOnlyModel):
     @property
     def hidden_multiplier(self):
-        return self._original_mod.config.hidden_size**0.5
+        return self.config.hidden_size**0.5
optimum/rbln/transformers/models/gemma2/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma2 import RBLNGemma2ForCausalLMConfig, RBLNGemma2ModelConfig
+from .modeling_gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2Model
optimum/rbln/transformers/models/gemma2/configuration_gemma2.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
+    # Create a configuration object
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/gemma2/gemma2_architecture.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class Gemma2Wrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Gemma2DecoderLayer
+
+    def get_rbln_attn_class(self):
+        return Gemma2Attention
+
+    def get_rbln_model_class(self):
+        return Gemma2Model
+
+
+class Gemma2DecoderLayer(DecoderOnlyLayer):
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
+        )
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+        hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma2Attention(DecoderOnlyAttention):
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
+
+
+class Gemma2Model(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self.config.hidden_size**0.5
optimum/rbln/transformers/models/gemma2/modeling_gemma2.py
@@ -0,0 +1,101 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+)
+from .gemma2_architecture import Gemma2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run a pre-trained transformers Gemma2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into an RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNGemma2ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGemma2ForCausalLMConfig
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
+
+
+class RBLNGemma2Model(RBLNDecoderOnlyModel):
+    """
+    The Gemma2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run a pre-trained transformers Gemma2Model model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2Model model into an RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -16,7 +16,6 @@ import copy
 from typing import Optional, Tuple, Union
 
 import torch
-from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
@@ -64,6 +63,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -94,13 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
         else:
             seq_positions = cache_position[:, :1]
 
-        sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+        cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+        sliding_cache_pos = (cache_seq_len, cache_offset)
 
+        all_hidden_states = () if output_hidden_states else None
        for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
            is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
            hidden_states = layer(
                hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                past_key_values=past_key_values,
                cos=cos_local if is_sliding else cos_global,
@@ -110,15 +115,14 @@ class Gemma3TextModel(DecoderOnlyModel):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
-    def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.pre_feedforward_layernorm
-
-    def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.post_feedforward_layernorm
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
 
     def forward(
         self,
@@ -158,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
 
 
 class Gemma3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
-
-    def get_attn_scale(self):
-        return self._original_mod.config.query_pre_attn_scalar**-0.5
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
+
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py
@@ -16,7 +16,7 @@ from typing import Optional
 import rebel
 import torch
 
-from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
 from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         super().__init__(*args, **kwargs)
         self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
         self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
-        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -106,6 +105,8 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
        while step < query_length:
            if self.rbln_config.use_image_prefill:
                # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
 
            if is_image_prefill:
-                logits = self.image_prefill(
+                outputs = self.image_prefill(
                    input_chunk,
                    cache_pos_chunk,
                    block_tables,
@@ -157,7 +158,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                    lora_int_ids if self.rbln_config.use_lora else None,
                )
            else:
-                logits = self.prefill(
+                outputs = self.prefill(
                    input_chunk,
                    cache_pos_chunk,
                    block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                    lora_int_ids if self.rbln_config.use_lora else None,
                )
 
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
            padded_cache_lengths += current_padded_cache_lengths
            step += num_processed_tokens
 
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-        lora_int_ids: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        if self.rbln_config.use_lora and lora_int_ids is None:
-            if self.lora_int_ids is None:
-                raise ValueError(
-                    "lora_int_id is required when using LoRA. "
-                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
-                )
-
-            lora_int_ids = self.lora_int_ids
-
-        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
-            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
 
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+            all_hidden_states = concatenated_hidden_states
 
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-            if self.batch_size < block_tables.shape[0]:
-                block_tables = block_tables[: self.batch_size]
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
 
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
+                output_logits = new_output_logits
 
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
 
-        return RBLNDecoderOnlyOutput(logits=logits)
+        return RBLNGemma3ForCausalLMOutput(
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
+        )