optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (130)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  54. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  55. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  56. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  57. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  58. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  59. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  60. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  63. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  64. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  65. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  66. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  67. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  68. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  69. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  70. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  71. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
  72. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  73. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  74. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  75. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  76. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  77. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  78. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  79. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  80. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  81. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  82. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  83. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  84. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  85. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  86. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  87. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  91. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  92. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  93. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  94. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
  95. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
  96. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  97. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  98. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  99. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  100. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  101. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  102. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  103. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  104. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  105. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  106. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  108. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  109. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  110. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  111. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  112. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  113. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  114. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  115. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  116. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  117. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  118. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  119. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  123. optimum/rbln/utils/model_utils.py +20 -0
  124. optimum/rbln/utils/submodule.py +6 -8
  125. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
  127. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  128. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  129. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
  130. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/distilbert/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_distilbert import RBLNDistilBertForQuestionAnsweringConfig
+ from .modeling_distilbert import RBLNDistilBertForQuestionAnswering
+
+
+ __all__ = ["RBLNDistilBertForQuestionAnsweringConfig", "RBLNDistilBertForQuestionAnswering"]
optimum/rbln/transformers/models/distilbert/configuration_distilbert.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNModelForQuestionAnsweringConfig
+
+
+ class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
+ ""
optimum/rbln/transformers/models/distilbert/modeling_distilbert.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...modeling_generic import RBLNModelForQuestionAnswering
+
+
+ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+ rbln_model_input_names = ["input_ids", "attention_mask"]
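Taken together, these three new files wire up DistilBERT question answering (`RBLNDistilBertForQuestionAnswering` plus its config, re-exported from the new subpackage). A minimal usage sketch follows; the checkpoint name, the `export=True` compile flow, and the output fields are assumptions that mirror how the other RBLN model classes in this diff are used, not documented behavior of this specific class.

```python
# Hedged sketch, not documented API for this class: compile and query the new
# DistilBERT question-answering model on an RBLN device.
from transformers import AutoTokenizer
from optimum.rbln import RBLNDistilBertForQuestionAnswering

model_id = "distilbert-base-cased-distilled-squad"  # assumed example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts and compiles the Hugging Face checkpoint, following the
# from_pretrained(export=True) pattern used by the other RBLN classes in this diff.
model = RBLNDistilBertForQuestionAnswering.from_pretrained(model_id, export=True)

question = "What does optimum-rbln target?"
context = "optimum-rbln compiles and runs Hugging Face models on RBLN NPUs."
# Depending on the compiled input shape, padding to the compiled sequence length
# (padding="max_length") may be required; kept simple here.
inputs = tokenizer(question, context, return_tensors="pt")

# rbln_model_input_names = ["input_ids", "attention_mask"], so only these two tensors are passed.
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
# Output layout is assumed to follow the usual HF QA head (start_logits / end_logits).
start = int(outputs.start_logits.argmax(-1))
end = int(outputs.end_logits.argmax(-1))
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```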
optimum/rbln/transformers/models/exaone/configuration_exaone.py
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
  class RBLNExaoneForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
- pass
+ """
+ Configuration class for RBLN Exaone models.
+
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+ Example usage:
+ ```python
+ from optimum.rbln import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
+
+ # Create a configuration object
+ config = RBLNExaoneForCausalLMConfig(
+ batch_size=1,
+ max_seq_len=4096,
+ tensor_parallel_size=4
+ )
+
+ # Use the configuration with from_pretrained
+ model = RBLNExaoneForCausalLM.from_pretrained(
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+ export=True,
+ rbln_config=config
+ )
+ ```
+ """
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -60,7 +60,11 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
  new_layer = ExaoneLayer(layer, new_self_attn)
  new_layers.append(new_layer)
  new_model = ExaoneModel(
- causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+ causal_lm.transformer,
+ new_layers,
+ partition_len=self.kvcache_partition_len,
+ max_seq_len=max_seq_len,
+ sliding_window_layers=self.sliding_window_layers,
  )
  new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
  return new_causal_lm
optimum/rbln/transformers/models/exaone/modeling_exaone.py
@@ -13,7 +13,11 @@
  # limitations under the License.
 
 
+ import inspect
+ from typing import Any, Callable
+
  from transformers import AutoModelForCausalLM
+ from transformers.generation.utils import GenerationMixin
 
  from ....utils import logging
  from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -25,22 +29,79 @@ logger = logging.get_logger(__name__)
 
  class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  """
- The Exaone Model transformer with a language modeling head on top (linear layer with weights tied to the input
- embeddings).
+ The Exaone Model transformer with a language modeling head (linear layer) on top.
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
- This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
- library implements for all its model.
+ A class to convert and run pre-trained transformers based ExaoneForCausalLM model on RBLN devices.
+ It implements the methods to convert a pre-trained transformers ExaoneForCausalLM model into a RBLN transformer model by:
 
- It implements the methods to convert a pre-trained transformers Exaone model into a RBLN transformer model by:
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
  - compiling the resulting graph using the RBLN compiler.
 
+ **Configuration:**
+ This model uses [`RBLNExaoneForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+ the `rbln_config` parameter should be an instance of [`RBLNExaoneForCausalLMConfig`] or a dictionary conforming to its structure.
+
+ See the [`RBLNExaoneForCausalLMConfig`] class for all available configuration options.
+
+ Examples:
+ ```python
+ from optimum.rbln import RBLNExaoneForCausalLM
+
+ # Simple usage using rbln_* arguments
+ # `max_seq_len` is automatically inferred from the model config
+ model = RBLNExaoneForCausalLM.from_pretrained(
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+ export=True,
+ rbln_batch_size=1,
+ rbln_tensor_parallel_size=4,
+ )
+
+
+ # Using a config dictionary
+ rbln_config = {
+ "batch_size": 1,
+ "max_seq_len": 4096,
+ "tensor_parallel_size": 4,
+ }
+ model = RBLNExaoneForCausalLM.from_pretrained(
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+ export=True,
+ rbln_config=rbln_config
+ )
+
+
+ # Using a RBLNExaoneForCausalLMConfig instance (recommended for type checking)
+ from optimum.rbln import RBLNExaoneForCausalLMConfig
+
+ config = RBLNExaoneForCausalLMConfig(
+ batch_size=1,
+ max_seq_len=4096,
+ tensor_parallel_size=4
+ )
+ model = RBLNExaoneForCausalLM.from_pretrained(
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+ export=True,
+ rbln_config=config
+ )
+ ```
  """
 
  _decoder_wrapper_cls = ExaoneForCausalLMWrapper
  _hf_class = AutoModelForCausalLM
+ _supports_cache_class = True
 
  @classmethod
  def from_pretrained(cls, *args, **kwargs):
  kwargs.setdefault("trust_remote_code", True)
  return super().from_pretrained(*args, **kwargs)
+
+ def __getattr__(self, __name: str) -> Any:
+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(GenerationMixin, __name)
+
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+ return val
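The `__getattr__` hook added above lets this RBLN class pick up `generate()` and the other GenerationMixin helpers without inheriting from that mixin: a failed attribute lookup is resolved against `GenerationMixin`, and any function found there that expects `self` is re-bound to the RBLN model instance on the fly. A standalone sketch of the same pattern (toy classes, not optimum-rbln code):

```python
# Toy illustration of the __getattr__ redirect pattern shown in the diff above:
# attribute lookups that miss on the instance are resolved against a mixin class,
# and functions found there are re-bound to the instance before being returned.
import inspect
from typing import Any, Callable


class GreetingMixin:
    def greet(self, name):
        return f"hello {name} from {type(self).__name__}"


class Model:
    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(GreetingMixin, __name)  # raises AttributeError if truly missing

        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val


print(Model().greet("rbln"))  # -> "hello rbln from Model"
```

Calling `Model().greet(...)` here is analogous to calling `model.generate(...)` on the compiled Exaone model: the method lives on the mixin, but runs against the RBLN instance.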
optimum/rbln/transformers/models/gemma/configuration_gemma.py
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
  class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
- pass
+ """
+ Configuration class for RBLN Gemma models.
+
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+ Example usage:
+ ```python
+ from optimum.rbln import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+
+ # Create a configuration object
+ config = RBLNGemmaForCausalLMConfig(
+ batch_size=1,
+ max_seq_len=4096,
+ tensor_parallel_size=4
+ )
+
+ # Use the configuration with from_pretrained
+ model = RBLNGemmaForCausalLM.from_pretrained(
+ "google/gemma-7b",
+ export=True,
+ rbln_config=config
+ )
+ ```
+ """
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -52,7 +52,11 @@ class GemmaWrapper(DecoderOnlyWrapper):
  new_layer = DecoderOnlyLayer(layer, new_self_attn)
  new_layers.append(new_layer)
  new_model = GemmaModel(
- causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+ causal_lm.model,
+ new_layers,
+ partition_len=self.kvcache_partition_len,
+ max_seq_len=max_seq_len,
+ sliding_window_layers=self.sliding_window_layers,
  )
  new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
  return new_causal_lm
optimum/rbln/transformers/models/gemma/modeling_gemma.py
@@ -27,8 +27,57 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
  A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
  It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
+
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
  - compiling the resulting graph using the RBLN compiler.
+
+ **Configuration:**
+ This model uses [`RBLNGemmaForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+ the `rbln_config` parameter should be an instance of [`RBLNGemmaForCausalLMConfig`] or a dictionary conforming to its structure.
+
+ See the [`RBLNGemmaForCausalLMConfig`] class for all available configuration options.
+
+ Examples:
+ ```python
+ from optimum.rbln import RBLNGemmaForCausalLM
+
+ # Simple usage using rbln_* arguments
+ # `max_seq_len` is automatically inferred from the model config
+ model = RBLNGemmaForCausalLM.from_pretrained(
+ "google/gemma-7b",
+ export=True,
+ rbln_batch_size=1,
+ rbln_tensor_parallel_size=4,
+ )
+
+
+ # Using a config dictionary
+ rbln_config = {
+ "batch_size": 1,
+ "max_seq_len": 4096,
+ "tensor_parallel_size": 4,
+ }
+ model = RBLNGemmaForCausalLM.from_pretrained(
+ "google/gemma-7b",
+ export=True,
+ rbln_config=rbln_config
+ )
+
+
+ # Using a RBLNGemmaForCausalLMConfig instance (recommended for type checking)
+ from optimum.rbln import RBLNGemmaForCausalLMConfig
+
+ config = RBLNGemmaForCausalLMConfig(
+ batch_size=1,
+ max_seq_len=4096,
+ tensor_parallel_size=4
+ )
+ model = RBLNGemmaForCausalLM.from_pretrained(
+ "google/gemma-7b",
+ export=True,
+ rbln_config=config
+ )
+ ```
  """
 
  _decoder_wrapper_cls = GemmaWrapper
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py
@@ -11,7 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Optional
+ from typing import Any, Dict, Optional
 
  import rebel
 
@@ -26,7 +26,7 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
  prefill_chunk_size: Optional[int] = None,
  use_position_ids: Optional[bool] = None,
  use_attention_mask: Optional[bool] = None,
- **kwargs,
+ **kwargs: Dict[str, Any],
  ):
  # use_attention_mask and use_position_ids are always True for Gemma3
  use_attention_mask = use_attention_mask or True
@@ -53,7 +53,7 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
  batch_size: Optional[int] = None,
  vision_tower: Optional[RBLNModelConfig] = None,
  language_model: Optional[RBLNModelConfig] = None,
- **kwargs,
+ **kwargs: Dict[str, Any],
  ):
  """
  Args:
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -16,11 +16,9 @@ import copy
  from typing import TYPE_CHECKING, Optional, Tuple, Union
 
  import torch
- from torch import nn
  from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
  from ..decoderonly.decoderonly_architecture import (
- AttentionOp,
  DecoderOnlyAttention,
  DecoderOnlyFlashAttention,
  DecoderOnlyForCausalLM,
@@ -28,7 +26,6 @@ from ..decoderonly.decoderonly_architecture import (
  DecoderOnlyModel,
  DecoderOnlyWrapper,
  RotaryEmbedding,
- SlidingWindowAttentionOp,
  slice_and_unsqueeze_cos_sin,
  )
 
@@ -50,13 +47,14 @@ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
 
  def convert_to_rbln_causal_lm(self, causal_lm: "Gemma3ForCausalLM", max_seq_len: int):
  new_layers = []
- for layer in causal_lm.model.layers:
- if layer.is_sliding:
+ for layer_idx, layer in enumerate(causal_lm.model.layers):
+ if layer_idx in self.sliding_window_layers:
  new_self_attn = Gemma3Attention(
  layer.self_attn,
  use_attention_mask=None, # FIXME: no use in SWA
  use_position_ids=self.use_position_ids,
  kvcache_block_size=self.config.sliding_window,
+ is_sliding=True,
  )
  else:
  if self.attn_impl == "eager":
@@ -65,6 +63,7 @@
  use_attention_mask=self.use_attention_mask,
  use_position_ids=self.use_position_ids,
  kvcache_block_size=self.kvcache_block_size,
+ is_sliding=False,
  )
  elif self.attn_impl == "flash_attn":
  new_self_attn = Gemma3FlashAttention(
@@ -85,131 +84,14 @@
  new_layers,
  partition_len=self.kvcache_partition_len,
  max_seq_len=max_seq_len,
+ sliding_window_layers=self.sliding_window_layers,
  )
- new_causal_lm = Gemma3ForCausalLM(causal_lm, new_model)
+ new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
  return new_causal_lm
 
- def forward(self, *args):
- if self.phase == "decode":
- (
- input_ids_or_inputs_embeds,
- attention_mask, # used in global layer, 2D attn_mask for padded KVcache.
- cache_position,
- position_ids,
- golbal_block_tables,
- local_block_tables,
- *past_key_values,
- ) = args
- query_position = None
-
- elif "prefill" in self.phase:
- (
- input_ids_or_inputs_embeds,
- attention_mask,
- cache_position,
- position_ids,
- query_position,
- golbal_block_tables,
- local_block_tables,
- *past_key_values,
- ) = args
-
- else:
- raise ValueError(f"Unknown phase: {self.phase}")
-
- if input_ids_or_inputs_embeds.ndim == 2:
- input_ids = input_ids_or_inputs_embeds
- inputs_embeds = None
- elif input_ids_or_inputs_embeds.ndim == 3:
- input_ids = None
- inputs_embeds = input_ids_or_inputs_embeds
- else:
- raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
-
- if len(past_key_values) != 2 * self.num_hidden_layers:
- raise ValueError(
- f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
- )
-
- # [key, value] * n_layer -> ( (key, value) ) * n_layer
- # cache shape : batch, n_heads, 1, max_seq_len, head_dim
- _past_key_values = []
- for i in range(self.config.num_hidden_layers):
- key_states = past_key_values[i * 2]
- value_states = past_key_values[i * 2 + 1]
- past_key_value = [key_states, value_states]
- _past_key_values.append(past_key_value)
- past_key_values = _past_key_values
-
- logit = self.causal_lm(
- input_ids=input_ids,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- cache_position=cache_position,
- position_ids=position_ids,
- query_position=query_position,
- past_key_values=past_key_values,
- rotary_emb=(self.rotary_emb_global, self.rotary_emb_local),
- global_block_tables=golbal_block_tables,
- local_block_tables=local_block_tables,
- )
-
- return logit
-
-
- class Gemma3ForCausalLM(DecoderOnlyForCausalLM):
- def forward(
- self,
- input_ids: torch.Tensor = None,
- inputs_embeds: torch.Tensor = None,
- attention_mask: torch.Tensor = None,
- cache_position: torch.Tensor = None,
- position_ids: torch.Tensor = None,
- query_position: torch.Tensor = None,
- past_key_values: Tuple[Tuple[torch.Tensor]] = None,
- rotary_emb: nn.Module = None,
- global_block_tables: Optional[torch.Tensor] = None,
- local_block_tables: Optional[torch.Tensor] = None,
- ):
- # outputs
- hidden_states = self.model(
- input_ids=input_ids,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- cache_position=cache_position,
- position_ids=position_ids,
- query_position=query_position,
- past_key_values=past_key_values,
- rotary_emb=rotary_emb,
- global_block_tables=global_block_tables,
- local_block_tables=local_block_tables,
- )
-
- if "prefill" in self.phase:
- hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
-
- logits = self.lm_head(hidden_states)
-
- # Apply final logit softmaxing if configured, e.g. for Gemma2
- if getattr(self.config, "final_logit_softcapping", None) is not None:
- logits = logits / self.config.final_logit_softcapping
- logits = torch.tanh(logits)
- logits = logits * self.config.final_logit_softcapping
-
- return logits
-
 
  class Gemma3TextModel(DecoderOnlyModel):
- def get_local_cache_positions(self, position_ids, query_position):
- max_cache_len = self._original_mod.config.sliding_window
- valid_input_len = 1 if query_position is None else query_position + 1
- cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
- cache_offset = (
- torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
- ) # cache offset for next steps
-
- return cache_seq_len, cache_offset
-
+ # Different from DecoderOnlyModel, this model has global and local rotary embeddings.
  def forward(
  self,
  input_ids: torch.Tensor = None,
@@ -254,37 +136,23 @@ class Gemma3TextModel(DecoderOnlyModel):
 
  sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
- for layer in self.layers:
- if layer.is_sliding:
- hidden_states = layer(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- seq_positions=sliding_cache_pos,
- past_key_values=past_key_values,
- cos=cos_local,
- sin=sin_local,
- block_tables=local_block_tables,
- )
- else:
- hidden_states = layer(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- seq_positions=seq_positions,
- past_key_values=past_key_values,
- cos=cos_global,
- sin=sin_global,
- block_tables=global_block_tables,
- )
+ for layer_idx, layer in enumerate(self.layers):
+ is_sliding = True if layer_idx in self.sliding_window_layers else False
+ hidden_states = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+ past_key_values=past_key_values,
+ cos=cos_local if is_sliding else cos_global,
+ sin=sin_local if is_sliding else sin_global,
+ block_tables=local_block_tables if is_sliding else global_block_tables,
+ )
 
  hidden_states = self.get_last_layernorm()(hidden_states)
  return hidden_states
 
 
  class Gemma3DecoderLayer(DecoderOnlyLayer):
- def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
- super().__init__(layer, self_attn)
- self.is_sliding = self._original_mod.is_sliding
-
  def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
  return self._original_mod.pre_feedforward_layernorm
 
@@ -328,69 +196,10 @@ class Gemma3Attention(DecoderOnlyAttention):
  self.o_proj = self._original_mod.o_proj
  self.q_norm = self._original_mod.q_norm
  self.k_norm = self._original_mod.k_norm
- self.is_sliding = self._original_mod.is_sliding
 
  def get_attn_scale(self):
  return self._original_mod.config.query_pre_attn_scalar**-0.5
 
- def get_attention(self):
- if self._original_mod.is_sliding:
- return SlidingWindowAttentionOp(
- self.num_heads,
- self.head_dim,
- self.num_key_value_heads,
- self.use_attention_mask,
- self.use_position_ids,
- )
- else:
- return AttentionOp(
- self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- seq_positions: torch.LongTensor,
- past_key_values: Tuple[Tuple[torch.Tensor]],
- cos: Optional[torch.Tensor] = None,
- sin: Optional[torch.Tensor] = None,
- block_tables: Optional[torch.Tensor] = None,
- ):
- batch_size, query_length, _ = hidden_states.size()
-
- query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
- query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
- 1, 2
- )
-
- query_states = self.q_norm(query_states)
- key_states = self.k_norm(key_states)
- query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
- batch_size = query_states.shape[0]
- if batch_size > 1 and "prefill" in self.phase:
- raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
-
- attn_output = self.attention(
- query_states,
- key_states,
- value_states,
- attention_mask,
- past_key_state=past_key_values[self.layer_idx][0],
- past_value_state=past_key_values[self.layer_idx][1],
- seq_position=seq_positions,
- scale=self.scale,
- block_tables=block_tables,
- block_size=self.kvcache_block_size,
- )
-
- attn_outputs = self.o_proj(attn_output)
- return attn_outputs
-
 
  class Gemma3FlashAttention(DecoderOnlyFlashAttention):
  def __post_init__(self):
@@ -400,47 +209,6 @@ class Gemma3FlashAttention(DecoderOnlyFlashAttention):
  self.o_proj = self._original_mod.o_proj
  self.q_norm = self._original_mod.q_norm
  self.k_norm = self._original_mod.k_norm
- self.is_sliding = self._original_mod.is_sliding
 
  def get_attn_scale(self):
  return self._original_mod.config.query_pre_attn_scalar**-0.5
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- seq_positions: torch.LongTensor,
- past_key_values: Tuple[Tuple[torch.Tensor]],
- cos: Optional[torch.Tensor] = None,
- sin: Optional[torch.Tensor] = None,
- block_tables: Optional[torch.Tensor] = None,
- ):
- batch_size, query_length, _ = hidden_states.size()
-
- query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
- query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
- 1, 2
- )
-
- query_states = self.q_norm(query_states)
- key_states = self.k_norm(key_states)
- query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
- attn_output = self.attention(
- query_states,
- key_states,
- value_states,
- attention_mask,
- past_key_state=past_key_values[self.layer_idx][0],
- past_value_state=past_key_values[self.layer_idx][1],
- seq_position=seq_positions,
- scale=self.scale,
- block_tables=block_tables,
- kvcache_block_size=self.kvcache_block_size,
- )
-
- attn_outputs = self.o_proj(attn_output)
- return attn_outputs
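Across the Gemma3 changes above, the per-layer `is_sliding` flags and the duplicated attention/forward overrides are removed; the wrapper now passes a `sliding_window_layers` index list down to the shared DecoderOnly classes and routes each layer's rotary embeddings, cache positions, and block tables by index membership. A toy sketch of that dispatch idea (illustrative only, not optimum-rbln code; the helper name and the six-layer layout are invented):

```python
# Toy sketch: route per-layer inputs by membership in a sliding-window index list,
# mirroring the is_sliding dispatch inside Gemma3TextModel.forward in the diff above.
from dataclasses import dataclass


@dataclass
class LayerInputs:
    rope: str          # "local" or "global" rotary embedding
    block_tables: str  # "local" or "global" KV-cache block tables


def route_layer_inputs(num_layers: int, sliding_window_layers: list[int]) -> list[LayerInputs]:
    routed = []
    for layer_idx in range(num_layers):
        is_sliding = layer_idx in sliding_window_layers
        routed.append(
            LayerInputs(
                rope="local" if is_sliding else "global",
                block_tables="local" if is_sliding else "global",
            )
        )
    return routed


# Hypothetical 6-layer model where every layer except 2 and 5 uses sliding-window attention.
print(route_layer_inputs(6, sliding_window_layers=[0, 1, 3, 4]))
```

In the library the same decision is made inline per layer; the sketch only isolates the index-membership idea that replaces the old `layer.is_sliding` attribute checks.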