optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_distilbert import RBLNDistilBertForQuestionAnsweringConfig
16
+ from .modeling_distilbert import RBLNDistilBertForQuestionAnswering
17
+
18
+
19
+ __all__ = ["RBLNDistilBertForQuestionAnsweringConfig", "RBLNDistilBertForQuestionAnswering"]
@@ -0,0 +1,24 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...configuration_generic import RBLNModelForQuestionAnsweringConfig
16
+
17
+
18
+ class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
19
+ """
20
+ Configuration class for RBLNDistilBertForQuestionAnswering.
21
+
22
+ This configuration class stores the configuration parameters specific to
23
+ RBLN-optimized DistilBERT models for question answering tasks.
24
+ """
@@ -0,0 +1,27 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...modeling_generic import RBLNModelForQuestionAnswering
16
+
17
+
18
+ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
19
+ """
20
+ RBLN optimized DistilBERT model for question answering tasks.
21
+
22
+ This class provides hardware-accelerated inference for DistilBERT models
23
+ on RBLN devices, supporting extractive question answering tasks where
24
+ the model predicts start and end positions of answers in text.
25
+ """
26
+
27
+ rbln_model_input_names = ["input_ids", "attention_mask"]
@@ -16,4 +16,9 @@ from ...configuration_generic import RBLNModelForDepthEstimationConfig
16
16
 
17
17
 
18
18
  class RBLNDPTForDepthEstimationConfig(RBLNModelForDepthEstimationConfig):
19
- pass
19
+ """
20
+ Configuration class for RBLNDPTForDepthEstimation.
21
+
22
+ This configuration class stores the configuration parameters specific to
23
+ RBLN-optimized DPT (Dense Prediction Transformer) models for depth estimation tasks.
24
+ """
@@ -17,4 +17,9 @@ from ...modeling_generic import RBLNModelForDepthEstimation
17
17
 
18
18
 
19
19
  class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
20
- pass
20
+ """
21
+ RBLN optimized DPT model for depth estimation tasks.
22
+
23
+ This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
24
+ models on RBLN devices, supporting monocular depth estimation from single images.
25
+ """
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
16
16
 
17
17
 
18
18
  class RBLNExaoneForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
19
- pass
19
+ """
20
+ Configuration class for RBLN Exaone models.
21
+
22
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
23
+
24
+ Example usage:
25
+ ```python
26
+ from optimum.rbln import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
27
+
28
+ # Create a configuration object
29
+ config = RBLNExaoneForCausalLMConfig(
30
+ batch_size=1,
31
+ max_seq_len=4096,
32
+ tensor_parallel_size=4
33
+ )
34
+
35
+ # Use the configuration with from_pretrained
36
+ model = RBLNExaoneForCausalLM.from_pretrained(
37
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
38
+ export=True,
39
+ rbln_config=config
40
+ )
41
+ ```
42
+ """
@@ -60,7 +60,11 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
60
60
  new_layer = ExaoneLayer(layer, new_self_attn)
61
61
  new_layers.append(new_layer)
62
62
  new_model = ExaoneModel(
63
- causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
63
+ causal_lm.transformer,
64
+ new_layers,
65
+ partition_len=self.kvcache_partition_len,
66
+ max_seq_len=max_seq_len,
67
+ sliding_window_layers=self.sliding_window_layers,
64
68
  )
65
69
  new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
66
70
  return new_causal_lm
@@ -13,7 +13,11 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
+ import inspect
17
+ from typing import Any, Callable
18
+
16
19
  from transformers import AutoModelForCausalLM
20
+ from transformers.generation.utils import GenerationMixin
17
21
 
18
22
  from ....utils import logging
19
23
  from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -25,22 +29,79 @@ logger = logging.get_logger(__name__)
25
29
 
26
30
  class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
27
31
  """
28
- The Exaone Model transformer with a language modeling head on top (linear layer with weights tied to the input
29
- embeddings).
32
+ The Exaone Model transformer with a language modeling head (linear layer) on top.
33
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
30
34
 
31
- This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
32
- library implements for all its model.
35
+ A class to convert and run pre-trained transformers based ExaoneForCausalLM model on RBLN devices.
36
+ It implements the methods to convert a pre-trained transformers ExaoneForCausalLM model into a RBLN transformer model by:
33
37
 
34
- It implements the methods to convert a pre-trained transformers Exaone model into a RBLN transformer model by:
35
38
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
36
39
  - compiling the resulting graph using the RBLN compiler.
37
40
 
41
+ **Configuration:**
42
+ This model uses [`RBLNExaoneForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
43
+ the `rbln_config` parameter should be an instance of [`RBLNExaoneForCausalLMConfig`] or a dictionary conforming to its structure.
44
+
45
+ See the [`RBLNExaoneForCausalLMConfig`] class for all available configuration options.
46
+
47
+ Examples:
48
+ ```python
49
+ from optimum.rbln import RBLNExaoneForCausalLM
50
+
51
+ # Simple usage using rbln_* arguments
52
+ # `max_seq_len` is automatically inferred from the model config
53
+ model = RBLNExaoneForCausalLM.from_pretrained(
54
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
55
+ export=True,
56
+ rbln_batch_size=1,
57
+ rbln_tensor_parallel_size=4,
58
+ )
59
+
60
+
61
+ # Using a config dictionary
62
+ rbln_config = {
63
+ "batch_size": 1,
64
+ "max_seq_len": 4096,
65
+ "tensor_parallel_size": 4,
66
+ }
67
+ model = RBLNExaoneForCausalLM.from_pretrained(
68
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
69
+ export=True,
70
+ rbln_config=rbln_config
71
+ )
72
+
73
+
74
+ # Using a RBLNExaoneForCausalLMConfig instance (recommended for type checking)
75
+ from optimum.rbln import RBLNExaoneForCausalLMConfig
76
+
77
+ config = RBLNExaoneForCausalLMConfig(
78
+ batch_size=1,
79
+ max_seq_len=4096,
80
+ tensor_parallel_size=4
81
+ )
82
+ model = RBLNExaoneForCausalLM.from_pretrained(
83
+ "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
84
+ export=True,
85
+ rbln_config=config
86
+ )
87
+ ```
38
88
  """
39
89
 
40
90
  _decoder_wrapper_cls = ExaoneForCausalLMWrapper
41
91
  _hf_class = AutoModelForCausalLM
92
+ _supports_cache_class = True
42
93
 
43
94
  @classmethod
44
95
  def from_pretrained(cls, *args, **kwargs):
45
96
  kwargs.setdefault("trust_remote_code", True)
46
97
  return super().from_pretrained(*args, **kwargs)
98
+
99
+ def __getattr__(self, __name: str) -> Any:
100
+ def redirect(func):
101
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
102
+
103
+ val = getattr(GenerationMixin, __name)
104
+
105
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
106
+ return redirect(val)
107
+ return val
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
16
16
 
17
17
 
18
18
  class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
19
- pass
19
+ """
20
+ Configuration class for RBLN Gemma models.
21
+
22
+ This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
23
+
24
+ Example usage:
25
+ ```python
26
+ from optimum.rbln import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
27
+
28
+ # Create a configuration object
29
+ config = RBLNGemmaForCausalLMConfig(
30
+ batch_size=1,
31
+ max_seq_len=4096,
32
+ tensor_parallel_size=4
33
+ )
34
+
35
+ # Use the configuration with from_pretrained
36
+ model = RBLNGemmaForCausalLM.from_pretrained(
37
+ "google/gemma-7b",
38
+ export=True,
39
+ rbln_config=config
40
+ )
41
+ ```
42
+ """
@@ -52,7 +52,11 @@ class GemmaWrapper(DecoderOnlyWrapper):
52
52
  new_layer = DecoderOnlyLayer(layer, new_self_attn)
53
53
  new_layers.append(new_layer)
54
54
  new_model = GemmaModel(
55
- causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
55
+ causal_lm.model,
56
+ new_layers,
57
+ partition_len=self.kvcache_partition_len,
58
+ max_seq_len=max_seq_len,
59
+ sliding_window_layers=self.sliding_window_layers,
56
60
  )
57
61
  new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
58
62
  return new_causal_lm
@@ -27,8 +27,57 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
27
27
 
28
28
  A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
29
29
  It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
30
+
30
31
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
31
32
  - compiling the resulting graph using the RBLN compiler.
33
+
34
+ **Configuration:**
35
+ This model uses [`RBLNGemmaForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
36
+ the `rbln_config` parameter should be an instance of [`RBLNGemmaForCausalLMConfig`] or a dictionary conforming to its structure.
37
+
38
+ See the [`RBLNGemmaForCausalLMConfig`] class for all available configuration options.
39
+
40
+ Examples:
41
+ ```python
42
+ from optimum.rbln import RBLNGemmaForCausalLM
43
+
44
+ # Simple usage using rbln_* arguments
45
+ # `max_seq_len` is automatically inferred from the model config
46
+ model = RBLNGemmaForCausalLM.from_pretrained(
47
+ "google/gemma-7b",
48
+ export=True,
49
+ rbln_batch_size=1,
50
+ rbln_tensor_parallel_size=4,
51
+ )
52
+
53
+
54
+ # Using a config dictionary
55
+ rbln_config = {
56
+ "batch_size": 1,
57
+ "max_seq_len": 4096,
58
+ "tensor_parallel_size": 4,
59
+ }
60
+ model = RBLNGemmaForCausalLM.from_pretrained(
61
+ "google/gemma-7b",
62
+ export=True,
63
+ rbln_config=rbln_config
64
+ )
65
+
66
+
67
+ # Using a RBLNGemmaForCausalLMConfig instance (recommended for type checking)
68
+ from optimum.rbln import RBLNGemmaForCausalLMConfig
69
+
70
+ config = RBLNGemmaForCausalLMConfig(
71
+ batch_size=1,
72
+ max_seq_len=4096,
73
+ tensor_parallel_size=4
74
+ )
75
+ model = RBLNGemmaForCausalLM.from_pretrained(
76
+ "google/gemma-7b",
77
+ export=True,
78
+ rbln_config=config
79
+ )
80
+ ```
32
81
  """
33
82
 
34
83
  _decoder_wrapper_cls = GemmaWrapper
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- from typing import Optional
14
+ from typing import Any, Dict, Optional
15
15
 
16
16
  import rebel
17
17
 
@@ -26,7 +26,7 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
26
26
  prefill_chunk_size: Optional[int] = None,
27
27
  use_position_ids: Optional[bool] = None,
28
28
  use_attention_mask: Optional[bool] = None,
29
- **kwargs,
29
+ **kwargs: Dict[str, Any],
30
30
  ):
31
31
  # use_attention_mask and use_position_ids are always True for Gemma3
32
32
  use_attention_mask = use_attention_mask or True
@@ -53,7 +53,7 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
53
53
  batch_size: Optional[int] = None,
54
54
  vision_tower: Optional[RBLNModelConfig] = None,
55
55
  language_model: Optional[RBLNModelConfig] = None,
56
- **kwargs,
56
+ **kwargs: Dict[str, Any],
57
57
  ):
58
58
  """
59
59
  Args: