optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen3/modeling_qwen3.py

@@ -12,24 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
-
-from transformers import PretrainedConfig
 
 from ....utils import logging
 from ...models.decoderonly import (
     RBLNDecoderOnlyModel,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNDecoderOnlyModelForCausalLMConfig,
 )
 from .qwen3_architecture import Qwen3Wrapper
 
 
 logger = logging.get_logger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
-
 
 class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
@@ -84,19 +77,6 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = Qwen3Wrapper
 
-    @classmethod
-    def _update_sliding_window_config(
-        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        # https://github.com/huggingface/transformers/issues/35896
-        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
-        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
-
-        rbln_config.cache_impl = "sliding_window"
-        rbln_config.sliding_window = model_config.sliding_window
-        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-        return rbln_config
-
     def forward(self, *args, **kwargs):
         kwargs["return_dict"] = True
         return super().forward(*args, **kwargs)
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py

@@ -22,10 +22,10 @@ class Qwen3Wrapper(DecoderOnlyWrapper):
 
 
 class Qwen3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.q_proj = self._original_mod.q_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
optimum/rbln/transformers/models/qwen3_moe/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen3_moe import RBLNQwen3MoeForCausalLMConfig
+from .modeling_qwen3_moe import RBLNQwen3MoeForCausalLM
optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py (new file)

@@ -0,0 +1,38 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen3MoeForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 Moe models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig
+    # Create a configuration object
+    config = RBLNQwen3MoeForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=262144,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen3-30B-A3B-Thinking-2507",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py (new file)

@@ -0,0 +1,68 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .qwen3_moe_architecture import Qwen3MoeWrapper
+
+
+class RBLNQwen3MoeForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Qwen3 Moe is a Mixture-of-Experts (MoE) variant of Qwen3, available as a base model and an aligned chat model.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3MoeForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3MoeForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3MoeForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3MoeForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3MoeForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3MoeForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen3-30B-A3B-Thinking-2507",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 262144,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNQwen3MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen3-30B-A3B-Thinking-2507",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNQwen3MoeForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNQwen3MoeForCausalLMConfig
+    config = RBLNQwen3MoeForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=262144,
+        tensor_parallel_size=4
+    )
+    model = RBLNQwen3MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen3-30B-A3B-Thinking-2507",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Qwen3MoeWrapper
optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py (new file)

@@ -0,0 +1,100 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
+
+
+class Qwen3MoeWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Qwen3MoeLayer
+
+    def get_rbln_attn_class(self):
+        return Qwen3MoeAttention
+
+
+class Qwen3MoeAttention(DecoderOnlyAttention):
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
+
+
+class Qwen3MoeLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = (
+            Qwen3MoeSparseMoeBlock(layer.mlp)
+            if layer.mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock"
+            else layer.mlp
+        )
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.norm_topk_prob = model.norm_topk_prob
+        self.gate = model.gate
+        self.experts = Qwen3MoeMLP(model.experts, self.top_k, self.norm_topk_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states
+
+
+class Qwen3MoeMLP(nn.Module):
+    def __init__(self, expert_list, top_k, norm_topk_prob):
+        super().__init__()
+        self.hidden_size = expert_list[0].hidden_size
+        self.intermediate_size = expert_list[0].intermediate_size
+        self.top_k = top_k
+        self.norm_topk_prob = norm_topk_prob
+        self.num_experts = len(expert_list)
+        self.gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.num_experts * self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj.weight.data = torch.stack([expert.gate_proj.weight.data for expert in expert_list], dim=0)
+        self.up_proj.weight.data = torch.stack([expert.up_proj.weight.data for expert in expert_list], dim=0)
+        self.down_proj.weight.data = torch.stack([expert.down_proj.weight.data for expert in expert_list], dim=0)
+
+    def forward(self, x, router_logits):
+        return torch.ops.rbln_custom_ops.custom_moe_glu(
+            x,
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+            router_logits,
+            self.top_k,
+            self.norm_topk_prob,
+        )
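Note: `torch.ops.rbln_custom_ops.custom_moe_glu` is one of the new custom kernels registered by `optimum/rbln/ops/moe.py` (+180 lines, not shown in this view). Its internals are not visible here, but under the standard Qwen3-MoE routing scheme (softmax over router logits, top-k expert selection, optional probability renormalization, SiLU-gated GLU per expert), an eager-mode reference for the math it is expected to compute could look like the hedged sketch below; the shapes follow the stacked weights built in `Qwen3MoeMLP`, and the function name is illustrative.

```python
import torch
import torch.nn.functional as F


def moe_glu_reference(
    x: torch.Tensor,              # (tokens, hidden)
    gate_w: torch.Tensor,         # (num_experts, intermediate, hidden), as stacked in Qwen3MoeMLP
    up_w: torch.Tensor,           # (num_experts, intermediate, hidden)
    down_w: torch.Tensor,         # (num_experts, hidden, intermediate)
    router_logits: torch.Tensor,  # (tokens, num_experts)
    top_k: int,
    norm_topk_prob: bool,
) -> torch.Tensor:
    # Standard Qwen3-MoE routing: softmax over all experts, then keep the top-k.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected = routing_weights.topk(top_k, dim=-1)
    if norm_topk_prob:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(x.dtype)

    out = torch.zeros_like(x)
    for e in range(gate_w.shape[0]):
        # Tokens (and their top-k slot) routed to expert e.
        token_idx, k_idx = (selected == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx]
        # SiLU-gated GLU MLP for this expert.
        expert_out = (F.silu(h @ gate_w[e].T) * (h @ up_w[e].T)) @ down_w[e].T
        out.index_add_(0, token_idx, expert_out * routing_weights[token_idx, k_idx, None])
    return out
```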
optimum/rbln/transformers/models/resnet/configuration_resnet.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 from ...configuration_generic import RBLNModelForImageClassificationConfig
 
 
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized ResNet models for image classification tasks.
     """
+
+    def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.output_hidden_states = output_hidden_states
optimum/rbln/transformers/models/resnet/modeling_resnet.py

@@ -13,7 +13,17 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
+
 from ...modeling_generic import RBLNModelForImageClassification
+from .configuration_resnet import RBLNResNetForImageClassificationConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with convolutional neural networks
     designed for computer vision tasks.
     """
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
+    ) -> "RBLNResNetForImageClassificationConfig":
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+        rbln_config = super()._update_rbln_config(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
+    ) -> torch.nn.Module:
+        class _ResNetForImageClassification(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
+                super().__init__()
+                self.model = model
+                self.output_hidden_states = output_hidden_states
+
+            def forward(self, *args, **kwargs):
+                output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
+                return output
+
+        return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
+
+    def forward(
+        self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        """
+        Forward pass for the RBLN-optimized ResNet model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
+                See hidden_states under returned tensors for more details.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutputWithNoAttention object.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states}. "
+                f"Please compile again with the correct argument."
+            )
+
+        return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
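Note: because `_wrap_model_if_needed` bakes `output_hidden_states` into the compiled graph, the value supplied at inference time must match the compile-time setting, which is exactly what the `ValueError` guard in `forward` enforces. A usage sketch follows; the checkpoint name and the `rbln_output_hidden_states` kwarg spelling are illustrative assumptions based on the `rbln_*` argument convention used elsewhere in this package.

```python
import torch

from optimum.rbln import RBLNResNetForImageClassification

# Compile with hidden states enabled; the setting is fixed in the RBLN graph.
# "microsoft/resnet-50" is an illustrative checkpoint, not one named in this diff.
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    export=True,
    rbln_output_hidden_states=True,  # assumed rbln_* spelling of the new config field
)

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch matching the compiled input shape

outputs = model(pixel_values, output_hidden_states=True)  # OK: matches compile-time setting
# model(pixel_values, output_hidden_states=False)  # would raise ValueError; recompile instead
```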
optimum/rbln/transformers/models/roberta/modeling_roberta.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
 
 
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
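Note: these `forward` overrides only pin down signatures and docstrings; the call path through the compiled runtime is unchanged. A fill-mask sketch under stated assumptions (the checkpoint, the `rbln_max_seq_len` kwarg, and the padding length are illustrative; RBLN models are compiled for fixed input shapes, so inputs are padded to the compiled length):

```python
import torch
from transformers import AutoTokenizer

from optimum.rbln import RBLNRobertaForMaskedLM

# Illustrative checkpoint; any RoBERTa masked-LM checkpoint would do.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = RBLNRobertaForMaskedLM.from_pretrained("roberta-base", export=True, rbln_max_seq_len=512)

inputs = tokenizer(
    "The capital of France is <mask>.",
    return_tensors="pt",
    padding="max_length",
    max_length=512,  # pad to the assumed compiled sequence length
)
outputs = model(inputs.input_ids, inputs.attention_mask)

# Decode the highest-scoring token at the mask position.
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode(outputs.logits[0, mask_pos].argmax()))
```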
optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py

@@ -15,6 +15,7 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.deprecation import deprecate_kwarg
 from ....utils.logging import get_logger
 
 
@@ -24,13 +25,13 @@ logger = get_logger()
 
 class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
     support_paged_attention = None
 
+    @deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
     def __init__(
         self,
         batch_size: Optional[int] = None,
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
-        pad_token_id: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         **kwargs: Any,
@@ -41,7 +42,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
             enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
             dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
             use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
                 PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
@@ -61,8 +61,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
 
         self.use_attention_mask = use_attention_mask
 
-        self.pad_token_id = pad_token_id
-
         if self.support_paged_attention:
             self.kvcache_num_blocks = kvcache_num_blocks
             self.kvcache_block_size = kvcache_block_size
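Note: `deprecate_kwarg` is imported from the new `optimum/rbln/utils/deprecation.py` (+213 lines, not shown in this view), which replaces the deleted `depreacate_utils.py`. Its exact behavior is not visible here; decorators of this shape conventionally warn when the old keyword is passed and then drop or remap it, roughly as in this hypothetical sketch:

```python
import functools
import warnings
from typing import Optional


def deprecate_kwarg(old_name: str, version: str, new_name: Optional[str] = None):
    """Hypothetical sketch: warn when `old_name` is passed, then drop or remap it."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
                value = kwargs.pop(old_name)
                if new_name is not None:
                    # Remap to the replacement keyword when one is given.
                    kwargs.setdefault(new_name, value)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

That convention matches the change above: `pad_token_id` disappears from the signature, but older call sites keep working (with a warning) until 0.10.0.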
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -20,8 +20,9 @@ import rebel
 import torch
 from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerationMixin
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
@@ -33,7 +34,7 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -140,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -209,8 +210,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
         if not cls.support_causal_attn:
             rbln_config.use_attention_mask = True
 
-        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        n_layer = getattr(model_config, "decoder_layers", None) or model_config.num_layers
+        n_head = getattr(model_config, "decoder_attention_heads", None) or model_config.num_heads
         d_kv = (
             model_config.d_kv
             if hasattr(model_config, "d_kv")
@@ -221,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
             model_config, "max_position_embeddings", None
         )
 
-        pad_token_id = getattr(model_config, "pad_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
-        pad_token_id = pad_token_id or -1
-        rbln_config.pad_token_id = pad_token_id
-
         if rbln_config.enc_max_seq_len is None:
             enc_max_seq_len = max_position_embeddings
             for tokenizer in preprocessors:
@@ -432,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
             inputs_tensor = torch.nn.functional.pad(
                 inputs_tensor,
                 (0, self.rbln_config.enc_max_seq_len - input_len),
-                value=self.rbln_config.pad_token_id,
+                value=self.config.pad_token_id,
             )
             model_kwargs["attention_mask"] = torch.nn.functional.pad(
                 model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
@@ -451,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
         model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
 
         return model_kwargs
+
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. Users can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)'s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            Generated sequences of token ids for models with a language modeling head.
+        """
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
+        if attention_mask is not None:
+            kwargs["attention_mask"] = attention_mask
+
+        return super().generate(input_ids, **kwargs)
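Note: with the explicit `generate` signature above, invocation is the same as with stock transformers. A sketch, with `RBLNT5ForConditionalGeneration` standing in as one concrete `RBLNModelForSeq2SeqLM` subclass and an illustrative checkpoint:

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")  # illustrative checkpoint
model = RBLNT5ForConditionalGeneration.from_pretrained("google-t5/t5-small", export=True)

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
# attention_mask and generation_config are folded into **kwargs by the new wrapper;
# encoder inputs are padded to enc_max_seq_len internally using config.pad_token_id.
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```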
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py

@@ -268,13 +268,12 @@ class Seq2SeqDecoder(torch.nn.Module):
 
     def __init__(self, model, layers, **kwargs):
         super().__init__()
-        self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self.embed_tokens = model.embed_tokens
-        self.final_layer_norm = getattr(model, "final_layer_norm", None)
-        self.__post_init__(**kwargs)
+        self.final_layer_norm = getattr(model, "final_layer_norm", None) or getattr(model, "layer_norm", None)
+        self.__post_init__(model, **kwargs)
 
-    def __post_init__(self, **kwargs):
+    def __post_init__(self, model: nn.Module, **kwargs):
         """
         Abstract method intended to be overridden by subclasses to modify or override
         the attributes of the original model after initialization.
@@ -344,12 +343,11 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
 
     def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
-        self._original_mod = decoder_layer
         self.self_attn = self_attn
         self.cross_attn = cross_attn
-        self.__post_init__()
+        self.__post_init__(decoder_layer)
 
-    def __post_init__(self, **kwargs):
+    def __post_init__(self, decoder_layer: nn.Module, **kwargs):
         """
         Abstract method intended to be overridden by subclasses to modify or override
         the attributes of the original model after initialization.
@@ -423,10 +421,9 @@
 class Seq2SeqSelfAttention(nn.Module):
     def __init__(self, attn, **kwargs):
         super().__init__()
-        self._original_mod = attn
-        self.__post_init__(**kwargs)
+        self.__post_init__(attn, **kwargs)
 
-    def __post_init__(self, **kwargs):
+    def __post_init__(self, attn: nn.Module, **kwargs):
         """
         Abstract method intended to be overridden by subclasses to modify or override
         the attributes of the original model after initialization.
@@ -495,8 +492,13 @@
 class Seq2SeqCrossAttention(nn.Module):
     def __init__(self, attn, **kwargs):
         super().__init__()
-        self._original_mod = attn
-        self.__post_init__(**kwargs)
+        self.__post_init__(attn, **kwargs)
+
+    def __post_init__(self, attn: nn.Module, **kwargs):
+        """
+        Optional post-init hook for subclasses (e.g., to register q/k/v/out projections).
+        """
+        pass
 
     def forward(
         self,
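Note: the same refactor (drop the stored `_original_mod` reference, pass the wrapped module into `__post_init__` explicitly) is applied to `Seq2SeqDecoder`, `Seq2SeqDecoderLayer`, `Seq2SeqSelfAttention`, and `Seq2SeqCrossAttention` alike, mirroring the `Qwen3Attention` change earlier in this diff. Under the new hook, a hypothetical model-specific subclass would register its projections like this:

```python
from torch import nn


# Hypothetical subclass for illustration; it assumes the Seq2SeqCrossAttention
# base class shown in the hunk above. Real per-model wrappers live in files
# such as bart_architecture.py and t5_architecture.py.
class MyCrossAttention(Seq2SeqCrossAttention):
    def __post_init__(self, attn: nn.Module, **kwargs):
        # Register projections taken from the wrapped HF attention module
        # instead of keeping a reference to the module itself.
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
```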