optimum-rbln: optimum_rbln-0.9.4a2-py3-none-any.whl → optimum_rbln-0.10.0.post1-py3-none-any.whl

Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen3_moe/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_qwen3_moe import RBLNQwen3MoeForCausalLMConfig
+ from .modeling_qwen3_moe import RBLNQwen3MoeForCausalLM
optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py
@@ -0,0 +1,38 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNQwen3MoeForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     """
+     Configuration class for RBLN Qwen3 MoE models.
+     This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig
+     # Create a configuration object
+     config = RBLNQwen3MoeForCausalLMConfig(
+         batch_size=1,
+         max_seq_len=262144,
+         tensor_parallel_size=4
+     )
+     # Use the configuration with from_pretrained
+     model = RBLNQwen3MoeForCausalLM.from_pretrained(
+         "Qwen/Qwen3-30B-A3B-Thinking-2507",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -0,0 +1,68 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .qwen3_moe_architecture import Qwen3MoeWrapper
+
+
+ class RBLNQwen3MoeForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     Qwen3 MoE is a Mixture-of-Experts (MoE) variant of Qwen3, available as a base model and an aligned chat model.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     A class to convert and run the pre-trained transformers Qwen3MoeForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers Qwen3MoeForCausalLM model into an RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+     **Configuration:**
+     This model uses [`RBLNQwen3MoeForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNQwen3MoeForCausalLMConfig`] or a dictionary conforming to its structure.
+     See the [`RBLNQwen3MoeForCausalLMConfig`] class for all available configuration options.
+     Examples:
+     ```python
+     from optimum.rbln import RBLNQwen3MoeForCausalLM
+     # Simple usage using rbln_* arguments
+     # `max_seq_len` is automatically inferred from the model config
+     model = RBLNQwen3MoeForCausalLM.from_pretrained(
+         "Qwen/Qwen3-30B-A3B-Thinking-2507",
+         export=True,
+         rbln_batch_size=1,
+         rbln_tensor_parallel_size=4,
+     )
+     # Using a config dictionary
+     rbln_config = {
+         "batch_size": 1,
+         "max_seq_len": 262144,
+         "tensor_parallel_size": 4,
+     }
+     model = RBLNQwen3MoeForCausalLM.from_pretrained(
+         "Qwen/Qwen3-30B-A3B-Thinking-2507",
+         export=True,
+         rbln_config=rbln_config
+     )
+     # Using an RBLNQwen3MoeForCausalLMConfig instance (recommended for type checking)
+     from optimum.rbln import RBLNQwen3MoeForCausalLMConfig
+     config = RBLNQwen3MoeForCausalLMConfig(
+         batch_size=1,
+         max_seq_len=262144,
+         tensor_parallel_size=4
+     )
+     model = RBLNQwen3MoeForCausalLM.from_pretrained(
+         "Qwen/Qwen3-30B-A3B-Thinking-2507",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+     _decoder_wrapper_cls = Qwen3MoeWrapper
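Once exported, the class is meant to slot into the usual transformers generation flow. A minimal sketch, assuming the standard `generate` API that RBLN causal-LM classes mirror (prompt and decoding arguments are illustrative):

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNQwen3MoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Thinking-2507")
model = RBLNQwen3MoeForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B-Thinking-2507",
    export=True,
    rbln_tensor_parallel_size=4,
)

inputs = tokenizer("What is a Mixture-of-Experts model?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```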
optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py
@@ -0,0 +1,100 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
+
+
+ class Qwen3MoeWrapper(DecoderOnlyWrapper):
+     def get_rbln_layer_class(self):
+         return Qwen3MoeLayer
+
+     def get_rbln_attn_class(self):
+         return Qwen3MoeAttention
+
+
+ class Qwen3MoeAttention(DecoderOnlyAttention):
+     def __post_init__(self, self_attn):
+         self.q_proj = self_attn.q_proj
+         self.k_proj = self_attn.k_proj
+         self.v_proj = self_attn.v_proj
+         self.o_proj = self_attn.o_proj
+         self.q_norm = self_attn.q_norm
+         self.k_norm = self_attn.k_norm
+
+
+ class Qwen3MoeLayer(DecoderOnlyLayer):
+     def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+         super().__init__(layer, self_attn, lora_config)
+         self.mlp = (
+             Qwen3MoeSparseMoeBlock(layer.mlp)
+             if layer.mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock"
+             else layer.mlp
+         )
+
+     def get_mlp(self) -> nn.Module:
+         return self.mlp
+
+
+ class Qwen3MoeSparseMoeBlock(nn.Module):
+     def __init__(self, model: nn.Module):
+         super().__init__()
+         self.num_experts = model.num_experts
+         self.top_k = model.top_k
+         self.norm_topk_prob = model.norm_topk_prob
+         self.gate = model.gate
+         self.experts = Qwen3MoeMLP(model.experts, self.top_k, self.norm_topk_prob)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.gate(hidden_states)
+         final_hidden_states = self.experts(hidden_states, router_logits)
+
+         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+         return final_hidden_states
+
+
+ class Qwen3MoeMLP(nn.Module):
+     def __init__(self, expert_list, top_k, norm_topk_prob):
+         super().__init__()
+         self.hidden_size = expert_list[0].hidden_size
+         self.intermediate_size = expert_list[0].intermediate_size
+         self.top_k = top_k
+         self.norm_topk_prob = norm_topk_prob
+         self.num_experts = len(expert_list)
+         self.gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.num_experts * self.intermediate_size, self.hidden_size, bias=False)
+         self.gate_proj.weight.data = torch.stack([expert.gate_proj.weight.data for expert in expert_list], dim=0)
+         self.up_proj.weight.data = torch.stack([expert.up_proj.weight.data for expert in expert_list], dim=0)
+         self.down_proj.weight.data = torch.stack([expert.down_proj.weight.data for expert in expert_list], dim=0)
+
+     def forward(self, x, router_logits):
+         return torch.ops.rbln_custom_ops.custom_moe_glu(
+             x,
+             self.gate_proj.weight,
+             self.up_proj.weight,
+             self.down_proj.weight,
+             router_logits,
+             self.top_k,
+             self.norm_topk_prob,
+         )
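`Qwen3MoeMLP` stacks every expert's `gate_proj`/`up_proj`/`down_proj` weights into one tensor per projection and hands routing plus the expert GLU to the fused `custom_moe_glu` kernel. As a reference for what that op presumably computes, here is an eager-mode sketch assuming Qwen3-style routing (softmax over all router logits, top-k selection, optional re-normalization of the selected probabilities) and SiLU-gated experts; `moe_glu_reference` is an illustrative name, not part of the package:

```python
import torch
import torch.nn.functional as F


def moe_glu_reference(x, gate_w, up_w, down_w, router_logits, top_k, norm_topk_prob):
    # x: (tokens, hidden); gate_w/up_w: (num_experts, intermediate, hidden);
    # down_w: (num_experts, hidden, intermediate); router_logits: (tokens, num_experts)
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)
    if norm_topk_prob:
        # Re-normalize the selected probabilities so they sum to 1 per token
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for slot in range(top_k):
        for e in range(gate_w.shape[0]):
            mask = topk_idx[:, slot] == e
            if mask.any():
                h = x[mask]
                # SiLU-gated GLU expert, weighted by the routing probability
                gated = F.silu(h @ gate_w[e].T) * (h @ up_w[e].T)
                out[mask] += topk_probs[mask, slot].unsqueeze(-1) * (gated @ down_w[e].T)
    return out
```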
optimum/rbln/transformers/models/resnet/configuration_resnet.py
@@ -13,7 +13,7 @@
  # limitations under the License.
 
 
- from typing import Optional
+ from typing import Any, Optional, Tuple, Union
 
  from ...configuration_generic import RBLNModelForImageClassificationConfig
 
@@ -26,17 +26,23 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
      RBLN-optimized ResNet models for image classification tasks.
      """
 
-     def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
+     def __init__(
+         self,
+         image_size: Optional[Union[int, Tuple[int, int]]] = None,
+         batch_size: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Any,
+     ):
          """
          Args:
              image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
                  Can be an integer for square images or a tuple (height, width).
              batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-             output_hidden_states (bool, optional) Whether or not to return the hidden states of all layers.
+             output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
              kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
          Raises:
              ValueError: If batch_size is not a positive integer.
          """
-         super().__init__(**kwargs)
+         super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
          self.output_hidden_states = output_hidden_states
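The widened signature makes the two common knobs explicit while staying backward compatible. A minimal usage sketch, assuming the matching `RBLNResNetForImageClassification` model class and an illustrative checkpoint:

```python
from optimum.rbln import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig

# image_size and batch_size are now explicit parameters, forwarded to the parent config
config = RBLNResNetForImageClassificationConfig(image_size=(224, 224), batch_size=2)
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    export=True,
    rbln_config=config,
)
```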
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -268,13 +268,12 @@ class Seq2SeqDecoder(torch.nn.Module):
 
      def __init__(self, model, layers, **kwargs):
          super().__init__()
-         self._original_mod = model
          self.layers = nn.ModuleList(layers)
          self.embed_tokens = model.embed_tokens
-         self.final_layer_norm = getattr(model, "final_layer_norm", None)
-         self.__post_init__(**kwargs)
+         self.final_layer_norm = getattr(model, "final_layer_norm", None) or getattr(model, "layer_norm", None)
+         self.__post_init__(model, **kwargs)
 
-     def __post_init__(self, **kwargs):
+     def __post_init__(self, model: nn.Module, **kwargs):
          """
          Abstract method intended to be overridden by subclasses to modify or override
          the attributes of the original model after initialization.
@@ -344,12 +343,11 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
 
      def __init__(self, decoder_layer, self_attn, cross_attn):
          super().__init__()
-         self._original_mod = decoder_layer
          self.self_attn = self_attn
          self.cross_attn = cross_attn
-         self.__post_init__()
+         self.__post_init__(decoder_layer)
 
-     def __post_init__(self, **kwargs):
+     def __post_init__(self, decoder_layer: nn.Module, **kwargs):
          """
          Abstract method intended to be overridden by subclasses to modify or override
          the attributes of the original model after initialization.
@@ -423,10 +421,9 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  class Seq2SeqSelfAttention(nn.Module):
      def __init__(self, attn, **kwargs):
          super().__init__()
-         self._original_mod = attn
-         self.__post_init__(**kwargs)
+         self.__post_init__(attn, **kwargs)
 
-     def __post_init__(self, **kwargs):
+     def __post_init__(self, attn: nn.Module, **kwargs):
          """
          Abstract method intended to be overridden by subclasses to modify or override
          the attributes of the original model after initialization.
@@ -495,8 +492,13 @@ class Seq2SeqSelfAttention(nn.Module):
  class Seq2SeqCrossAttention(nn.Module):
      def __init__(self, attn, **kwargs):
          super().__init__()
-         self._original_mod = attn
-         self.__post_init__(**kwargs)
+         self.__post_init__(attn, **kwargs)
+
+     def __post_init__(self, attn: nn.Module, **kwargs):
+         """
+         Optional post-init hook for subclasses (e.g., to register q/k/v/out projections).
+         """
+         pass
 
      def forward(
          self,
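The pattern across these seq2seq wrappers (and the T5, TimeSeriesTransformers, and Whisper classes below) is the same: the `_original_mod` back-reference is dropped, and the original HF module is instead passed once into `__post_init__`, where subclasses copy out the submodules they need. A sketch of a subclass under the new signature; `MySelfAttention` and the projection attribute names are illustrative, since each wrapped model exposes its own names:

```python
from torch import nn

from optimum.rbln.transformers.models.seq2seq.seq2seq_architecture import Seq2SeqSelfAttention


class MySelfAttention(Seq2SeqSelfAttention):
    def __post_init__(self, attn: nn.Module, **kwargs):
        # Pull the needed submodules out of the original HF attention module once,
        # instead of keeping the whole module alive via self._original_mod.
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        self.head_dim = attn.head_dim
```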
optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
+ from ...modeling_outputs import _validate_output_attentions, _validate_output_hidden_states
  from .configuration_siglip import RBLNSiglipVisionModelConfig
 
 
@@ -52,7 +53,7 @@ class _SiglipVisionModel(torch.nn.Module):
              interpolate_pos_encoding=self.interpolate_pos_encoding,
              output_attentions=self.output_attentions,
          )
-         return tuple(x for x in enc_out if x is not None)
+         return enc_out
 
 
  class RBLNSiglipVisionModel(RBLNModel):
@@ -138,23 +139,8 @@ class RBLNSiglipVisionModel(RBLNModel):
              The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
          """
 
-         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-         )
-
-         if output_attentions != self.rbln_config.output_attentions:
-             raise ValueError(
-                 f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
-                 f"Please compile again with the correct argument."
-             )
-
-         if output_hidden_states != self.rbln_config.output_hidden_states:
-             raise ValueError(
-                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                 f"Please compile again with the correct argument."
-             )
-
+         output_attentions = _validate_output_attentions(output_attentions, self.rbln_config)
+         output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
          if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
              raise ValueError(
                  f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
optimum/rbln/transformers/models/swin/configuration_swin.py
@@ -32,11 +32,6 @@ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
          Raises:
              ValueError: If batch_size is not a positive integer.
          """
-         super().__init__(**kwargs)
-         self.batch_size = batch_size or 1
-         if not isinstance(self.batch_size, int) or self.batch_size < 0:
-             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
-
-         self.image_size = image_size
+         super().__init__(batch_size=batch_size, image_size=image_size, **kwargs)
          self.output_hidden_states = output_hidden_states
          self.output_attentions = output_attentions
optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -111,9 +111,9 @@ class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
  class T5Decoder(Seq2SeqDecoder):
      has_pos_emb = False
 
-     def __post_init__(self, dec_max_seq_len: int = None):
-         self.invert_attention_mask = self._original_mod.invert_attention_mask
-         self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)
+     def __post_init__(self, model: nn.Module, dec_max_seq_len: int = None):
+         self.invert_attention_mask = model.invert_attention_mask
+         self._dec_position_bias = self.precompute_dec_position_bias(model, dec_max_seq_len)
 
      def precompute_dec_position_bias(self, model, dec_max_length):
          attn_layer = model.block[0].layer[0].SelfAttention
@@ -145,13 +145,12 @@ class T5Decoder(Seq2SeqDecoder):
  class T5Block(Seq2SeqDecoderLayer):
      def __init__(self, decoder_layer, self_attn):
          super().__init__(decoder_layer, self_attn, cross_attn=None)
-         self.__post_init__()
 
-     def __post_init__(self):
-         self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
-         self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
-         self.cross_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
-         self.ff_layer = self._original_mod.layer[2]
+     def __post_init__(self, decoder_layer: nn.Module):
+         self.self_attn_layer_norm = decoder_layer.layer[0].layer_norm
+         self.encoder_attn_layer_norm = decoder_layer.layer[1].layer_norm
+         self.cross_attn = T5CrossAttention(decoder_layer.layer[1].EncDecAttention)
+         self.ff_layer = decoder_layer.layer[2]
 
      def pre_self_attn_layer_norm(self, hidden_states):
          return self.self_attn_layer_norm(hidden_states)
@@ -167,13 +166,13 @@ class T5Block(Seq2SeqDecoderLayer):
 
 
  class T5LayerSelfAttention(Seq2SeqSelfAttention):
-     def __post_init__(self):
-         self.q_proj = self._original_mod.q
-         self.k_proj = self._original_mod.k
-         self.v_proj = self._original_mod.v
-         self.out_proj = self._original_mod.o
-         self.num_heads = self._original_mod.n_heads
-         self.head_dim = self._original_mod.key_value_proj_dim
+     def __post_init__(self, attn: nn.Module):
+         self.q_proj = attn.q
+         self.k_proj = attn.k
+         self.v_proj = attn.v
+         self.out_proj = attn.o
+         self.num_heads = attn.n_heads
+         self.head_dim = attn.key_value_proj_dim
          self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
 
      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
@@ -140,7 +140,6 @@ class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
  class TimeSeriesTransformersDecoder(nn.Module):
      def __init__(self, model, layers, **kwargs):
          super().__init__()
-         self._original_mod = model
          self.config = model.config
          self.layers = nn.ModuleList(layers)
          self.value_embedding = model.value_embedding
@@ -190,7 +189,6 @@ class TimeSeriesTransformersDecoder(nn.Module):
  class TimeSeriesTransformersDecoderLayer(nn.Module):
      def __init__(self, decoder_layer, self_attn, cross_attn):
          super().__init__()
-         self._original_mod = decoder_layer
          self.self_attn = self_attn
          self.encoder_attn = cross_attn
          self.embed_dim = decoder_layer.embed_dim
@@ -245,7 +243,6 @@ class TimeSeriesTransformersDecoderLayer(nn.Module):
  class TimeSeriesTransformersAttention(nn.Module):
      def __init__(self, attn, num_parallel_samples):
          super().__init__()
-         self._original_mod = attn
          self.q_proj = attn.q_proj
          self.k_proj = attn.k_proj
          self.v_proj = attn.v_proj
optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -51,22 +51,22 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
          return_segments: Optional[bool] = None,
          return_timestamps: Optional[bool] = None,
          return_token_timestamps: Optional[bool] = None,
-         **kwargs,
+         **kwargs: Optional[Dict[str, Any]],
      ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
          """
          The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
          Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
 
          Args:
-             input_features(torch.Tensor, optional): The input features to the model.
-             attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
-             generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+             input_features (torch.Tensor, optional): The input features to the model.
+             attention_mask (torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
+             generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
                  If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
                  Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
-             return_segments(bool, optional): Whether to return segments.
-             return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
-             return_token_timestamps(bool, optional): Whether to return token timestamps.
-             kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
+             return_segments (bool, optional): Whether to return segments.
+             return_timestamps (bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
+             return_token_timestamps (bool, optional): Whether to return token timestamps.
+             kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
 
          Returns:
              Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -154,7 +154,6 @@ class WhisperDecoderWrapper(torch.nn.Module):
  class WhisperDecoder(nn.Module):
      def __init__(self, model, layers, **kwargs):
          super().__init__()
-         self._original_mod = model
          self.layers = nn.ModuleList(layers)
          self.embed_tokens = model.embed_tokens
          self.layer_norm = model.layer_norm
@@ -210,7 +209,6 @@ class WhisperDecoder(nn.Module):
  class WhisperDecoderLayer(nn.Module):
      def __init__(self, decoder_layer, self_attn, cross_attn):
          super().__init__()
-         self._original_mod = decoder_layer
          self.self_attn = self_attn
          self.encoder_attn = cross_attn
          self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
@@ -263,7 +261,6 @@ class WhisperDecoderLayer(nn.Module):
  class WhisperAttention(nn.Module):
      def __init__(self, attn):
          super().__init__()
-         self._original_mod = attn
          self.q_proj = attn.q_proj
          self.k_proj = attn.k_proj
          self.v_proj = attn.v_proj
optimum/rbln/utils/hub.py
@@ -221,11 +221,12 @@ def load_weight_files(
      cache_dir: Optional[str] = None,
      force_download: bool = False,
      local_files_only: bool = False,
+     exception_keywords: Optional[List[str]] = None,
  ) -> list[str]:
      """
      Discover and download safetensors files for the given model id.
      """
- 
+     exception_keywords = exception_keywords or []
      if os.path.isdir(model_id):
          safetensor_files = glob.glob(f"{model_id}/*.safetensors")
      else:
@@ -237,17 +238,24 @@
 
          for file in repo_files:
              if file.endswith(".safetensors"):
-                 # Download the safetensors file
-                 downloaded_file = hf_hub_download(
-                     repo_id=model_id,
-                     filename=file,
-                     revision=revision,
-                     token=use_auth_token,
-                     cache_dir=cache_dir,
-                     force_download=force_download,
-                     local_files_only=local_files_only,
-                 )
-                 safetensor_files.append(downloaded_file)
+                 exclude = False
+                 for except_key in exception_keywords:
+                     if except_key in file:
+                         exclude = True
+                         break
+
+                 if not exclude:
+                     # Download the safetensors file
+                     downloaded_file = hf_hub_download(
+                         repo_id=model_id,
+                         filename=file,
+                         revision=revision,
+                         token=use_auth_token,
+                         cache_dir=cache_dir,
+                         force_download=force_download,
+                         local_files_only=local_files_only,
+                     )
+                     safetensor_files.append(downloaded_file)
          except Exception as e:
              logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
              raise e
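The new `exception_keywords` filter is a plain substring match against each repo file path before download. A usage sketch, assuming the model id stays the first positional parameter; the repo id is illustrative, and a keyword like `"consolidated"` would skip `consolidated.safetensors`-style duplicate shards:

```python
from optimum.rbln.utils.hub import load_weight_files

# Download all *.safetensors files except those whose path contains "consolidated"
files = load_weight_files(
    "mistralai/Mistral-7B-v0.1",
    exception_keywords=["consolidated"],
)
print(files)
```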
optimum/rbln/utils/deprecation.py
@@ -194,7 +194,7 @@
          message = f"{message} {additional_message}"
 
      # update minimum_action if argument is ALREADY deprecated (current version >= deprecated version)
-     if is_greater_or_equal_version:
+     if is_greater_or_equal_version and message is not None:
          # change to NOTIFY -> RAISE in case we want to raise error for already deprecated arguments
          if raise_if_greater_or_equal_version:
              minimum_action = Action.RAISE
@@ -211,3 +211,80 @@
          return wrapped_func
 
      return wrapper
+
+
+ def deprecate_method(
+     version: str,
+     new_method: Optional[str] = None,
+     raise_if_greater_or_equal_version: bool = True,
+     additional_message: Optional[str] = None,
+ ):
+     """
+     Decorator to mark a method as deprecated, optionally pointing to a replacement method.
+     This decorator allows you to:
+     - Notify users when a method is deprecated.
+     - Optionally specify a new method name that should be used instead.
+     - Raise an error if the deprecated method is called after the specified version.
+     Parameters:
+         version (`str`):
+             The version in which the method was (or will be) deprecated.
+         new_method (`Optional[str]`, *optional*):
+             The name of the new method to use instead. If specified, users will be directed to use this method.
+         raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `True`):
+             Whether to raise `ValueError` if current `optimum.rbln` version is greater than or equal to the deprecated version.
+         additional_message (`Optional[str]`, *optional*):
+             An additional message to append to the default deprecation message.
+     Returns:
+         Callable:
+             A wrapped function that handles the deprecation warning or error.
+     Examples:
+         >>> class MyClass:
+         ...     @deprecate_method(version="0.12.0", new_method="from_pretrained")
+         ...     def load(self, path):
+         ...         return self.from_pretrained(path)
+     """
+
+     deprecated_version = packaging.version.parse(version)
+     current_version = packaging.version.parse(__version__)
+     is_greater_or_equal_version = current_version >= deprecated_version
+
+     if is_greater_or_equal_version:
+         version_message = f"and removed starting from version {version}"
+     else:
+         version_message = f"and will be removed in version {version}"
+
+     def wrapper(func):
+         sig = inspect.signature(func)
+         function_named_args = set(sig.parameters.keys())
+         is_instance_method = "self" in function_named_args
+         is_class_method = "cls" in function_named_args
+
+         @wraps(func)
+         def wrapped_func(*args, **kwargs):
+             # Get class + method name for better warning message
+             method_name = func.__name__
+             if is_instance_method:
+                 method_name = f"{args[0].__class__.__name__}.{method_name}"
+             elif is_class_method:
+                 method_name = f"{args[0].__name__}.{method_name}"
+
+             # Build deprecation message
+             if new_method is not None:
+                 message = f"`{method_name}` is deprecated {version_message}. Use `{new_method}` instead."
+             else:
+                 message = f"`{method_name}` is deprecated {version_message}."
+
+             if additional_message is not None:
+                 message = f"{message} {additional_message}"
+
+             # Determine action based on version
+             if is_greater_or_equal_version and raise_if_greater_or_equal_version:
+                 raise ValueError(message)
+             else:
+                 logger.warning(message, stacklevel=2)
+
+             return func(*args, **kwargs)
+
+         return wrapped_func
+
+     return wrapper
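A short usage sketch of `deprecate_method`: below the cutoff version the wrapped call logs a warning and still runs; at or above it, with the default `raise_if_greater_or_equal_version=True`, it raises `ValueError` instead (`Loader` and `load` are illustrative names):

```python
from optimum.rbln.utils.deprecation import deprecate_method


class Loader:
    @deprecate_method(version="99.0.0", new_method="from_pretrained")
    def load(self, path):
        # Warns until optimum.rbln reaches 99.0.0; after that, calling it raises ValueError.
        return path


Loader().load("some/path")
# logs: `Loader.load` is deprecated and will be removed in version 99.0.0. Use `from_pretrained` instead.
```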