optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py
@@ -0,0 +1,122 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyWrapper,
+)
+
+
+class RBLNGptOssWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return RBLNGptOssLayer
+
+
+class RBLNGptOssLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = RBLNGptOssMLP(layer.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class RBLNGptOssTopKRouter(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.weight = model.weight
+        self.bias = model.bias
+
+    def forward(self, hidden_states):
+        return F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+
+
+class RBLNGptOssExperts(nn.Module):
+    def __init__(self, model, top_k: Optional[int] = None):
+        super().__init__()
+        self.intermediate_size = model.intermediate_size
+        self.num_experts = model.num_experts
+        self.hidden_size = model.hidden_size
+
+        self.register_buffer(
+            "gate_proj_blocks",
+            model.gate_up_proj_blocks.data[:, ::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("gate_proj_scales", model.gate_up_proj_scales.data[:, ::2, :])
+        self.register_buffer(
+            "gate_proj_bias",
+            model.gate_up_proj_bias.data[:, ::2].reshape(self.num_experts, self.intermediate_size),
+        )
+
+        self.register_buffer(
+            "up_proj_blocks",
+            model.gate_up_proj_blocks.data[:, 1::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("up_proj_scales", model.gate_up_proj_scales.data[:, 1::2, :])
+        self.register_buffer(
+            "up_proj_bias", model.gate_up_proj_bias.data[:, 1::2].reshape(self.num_experts, self.intermediate_size)
+        )
+
+        self.register_buffer(
+            "down_proj_blocks", model.down_proj_blocks.data.reshape(self.num_experts, self.hidden_size, -1)
+        )
+        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
+        self.register_buffer("down_proj_bias", model.down_proj_bias.data)
+
+        self.alpha = model.alpha  # 1.702
+        self.limit = model.limit  # 7.0
+        self.top_k = top_k
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
+        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
+            hidden_states,
+            self.gate_proj_blocks,
+            self.gate_proj_scales,
+            self.gate_proj_bias,
+            self.up_proj_blocks,
+            self.up_proj_scales,
+            self.up_proj_bias,
+            self.down_proj_blocks,
+            self.down_proj_scales,
+            self.down_proj_bias,
+            router_logits,
+            torch.tensor(self.alpha, dtype=hidden_states.dtype),
+            torch.tensor(self.limit, dtype=hidden_states.dtype),
+            k=self.top_k,
+            post_norm=True,
+        )
+
+
+class RBLNGptOssMLP(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.router = RBLNGptOssTopKRouter(model.router)
+        self.experts = RBLNGptOssExperts(model.experts, top_k=model.router.top_k)
+
+    def forward(self, hidden_states):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_logits=router_logits)
+        routed_out = routed_out.reshape(batch_size, sequence_length, hidden_dim)
+        return routed_out
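
For readers unfamiliar with the MoE path above: the router is a plain linear layer over the token states, and the MXFP4-packed expert weights plus the router logits are handed to a single fused RBLN op. The following dense-PyTorch sketch shows roughly what that op stands in for; the routing normalization (reading `post_norm=True` as renormalizing the top-k weights) and the clamped-GLU activation shape (GPT-OSS style, using `alpha` and `limit`) are assumptions, and the real kernel consumes MXFP4 blocks/scales rather than dense weights.

```python
import torch


def moe_glu_reference(x, gate_w, gate_b, up_w, up_b, down_w, down_b,
                      router_logits, alpha, limit, top_k):
    """Hypothetical dense sketch of a top-k routed clamped-GLU MoE layer (not the RBLN kernel).
    x: (tokens, hidden); gate_w/up_w: (experts, intermediate, hidden); down_w: (experts, hidden, intermediate)."""
    scores = router_logits.softmax(dim=-1)                      # (tokens, experts)
    top_w, top_i = scores.topk(top_k, dim=-1)
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)             # assumed meaning of post_norm=True
    out = torch.zeros_like(x)
    for e in range(gate_w.shape[0]):
        rows, slots = (top_i == e).nonzero(as_tuple=True)       # tokens routed to expert e
        if rows.numel() == 0:
            continue
        h = x[rows]
        gate = (h @ gate_w[e].T + gate_b[e]).clamp(max=limit)   # GPT-OSS-style clamping
        up = (h @ up_w[e].T + up_b[e]).clamp(min=-limit, max=limit)
        act = (up + 1) * (gate * torch.sigmoid(gate * alpha))   # GLU gate with alpha ~= 1.702
        out[rows] += top_w[rows, slots, None] * (act @ down_w[e].T + down_b[e])
    return out
```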
optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -0,0 +1,168 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
+from transformers.modeling_utils import PreTrainedModel, no_init_weights
+
+from ....utils.logging import get_logger
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModelConfig,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from ...utils.rbln_quantization import load_weight_files
+from .gpt_oss_architecture import RBLNGptOssWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+logger = get_logger(__name__)
+
+
+class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based GptOssForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers GptOssForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNGptOssForCausalLM
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=8,
+            rbln_kvcache_partition_len=8192,
+        )
+
+
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "tensor_parallel_size": 8,
+            "kvcache_partition_len": 8192,
+        }
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_config=rbln_config,
+        )
+
+
+        # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNGptOssForCausalLMConfig
+
+        config = RBLNGptOssForCausalLMConfig(
+            batch_size=1,
+            tensor_parallel_size=8,
+            kvcache_partition_len=8192,
+        )
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_config=config,
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = RBLNGptOssWrapper
+
+    @staticmethod
+    def _get_dtype(dtype: Union[str, torch.dtype] = None, torch_dtype: Union[str, torch.dtype] = None):
+        # For BC on torch_dtype argument
+        if torch_dtype is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # If both kwargs are provided, use `dtype`
+            dtype = dtype if dtype is not None else torch_dtype
+
+        # As mxfp4_quantizer's default dtype
+        if dtype is None or dtype == "auto":
+            dtype = torch.bfloat16
+
+        return dtype
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        *args,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        dtype: Union[str, torch.dtype] = None,
+        torch_dtype: Union[str, torch.dtype] = None,
+        config: Optional[PretrainedConfig] = None,
+        **kwargs,
+    ) -> PreTrainedModel:
+        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
+        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}
+
+        if config is None:
+            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)
+
+        dtype = cls._get_dtype(dtype, torch_dtype)
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)
+
+        _replace_with_mxfp4_linear(model, config)
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+
+        if rbln_config.use_attention_mask:
+            raise ValueError(
+                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
+            )
+
+        return rbln_config
+
+
+def _replace_with_mxfp4_linear(
+    model,
+    config,
+):
+    for name, module in model.named_children():
+        if module.__class__.__name__ == "GptOssExperts":
+            model._modules[name] = Mxfp4GptOssExperts(config)
+        if len(list(module.children())) > 0:
+            _replace_with_mxfp4_linear(module, config)
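
Since `RBLNGptOssForCausalLM` inherits the decoder-only causal LM interface, a compiled model can be driven like the other optimum-rbln text generators. A minimal end-to-end sketch, assuming the standard Hugging Face style `generate()` path provided by the base class (the save directory name is only an example):

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNGptOssForCausalLM

# Compile once with export=True, then persist the compiled artifacts for later reuse.
model = RBLNGptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    export=True,
    rbln_tensor_parallel_size=8,
    rbln_kvcache_partition_len=8192,
)
model.save_pretrained("gpt-oss-20b-rbln")  # hypothetical output directory

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("Briefly explain what a mixture-of-experts layer does.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)  # assumes the inherited generate() interface
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```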
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
-        super().__init__(**kwargs)
-        self.encoder = encoder
-        self.decoder = decoder
-        self.text_backbone = text_backbone
-        self.backbone = backbone
+
+        super().__init__(batch_size=batch_size, **kwargs)
+        self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
+        self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
+        self.text_backbone = self.initialize_submodule_config(
+            submodule_config=text_backbone, batch_size=self.batch_size
+        )
+        self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
         self.output_attentions = output_attentions if output_attentions is not None else False
         self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False
 
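
The change above stops storing the raw `encoder`/`decoder`/`text_backbone`/`backbone` arguments and instead runs each through `initialize_submodule_config`, so every submodule config picks up the parent `batch_size`. A hypothetical helper with roughly that shape (the real method lives in optimum-rbln's configuration utilities and may differ in detail):

```python
# Hypothetical sketch of a submodule-config initializer: accept None, a dict,
# or an existing config object, and default missing fields (here just
# batch_size) from the parent configuration.
def initialize_submodule_config_sketch(submodule_config, config_cls, **parent_defaults):
    if submodule_config is None:
        return config_cls(**parent_defaults)
    if isinstance(submodule_config, dict):
        return config_cls(**{**parent_defaults, **submodule_config})
    return submodule_config  # already a config instance; assumed to be used as-is
```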
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
 
         # mask vision for language
        if vision_attention_mask is not None:
-            # RBLN FIX: bool tensor to float tensor
-            mask = vision_attention_mask * torch.finfo(torch.float16).min
-            text_attn_weights = text_attn_weights.transpose(1, 2) + mask
-            text_attn_weights = text_attn_weights.transpose(1, 2)
+            # RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
+            mask = vision_attention_mask
+            if mask.dim() == 3:
+                mask = mask[..., 0]
+            mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+            text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
 
         text_attn_weights = text_attn_weights.softmax(dim=-1)
 
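
The rewritten mask handling drops the transpose-add-transpose dance in favor of a single additive mask broadcast over every head and text query position. A small shape walk-through (values are illustrative; the convention that nonzero mask entries mark padded vision tokens is assumed):

```python
import torch

batch, num_heads, text_len, vision_len = 2, 8, 7, 5
text_attn_weights = torch.zeros(batch * num_heads, text_len, vision_len)

# 1.0 marks a padded (masked) vision position in this illustration.
vision_attention_mask = torch.randint(0, 2, (batch, vision_len)).float()

mask = vision_attention_mask[:, None, None, :].repeat(1, num_heads, 1, 1).flatten(0, 1)
print(mask.shape)  # torch.Size([16, 1, 5]); broadcasts over the text (query) dimension

masked_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
probs = masked_weights.softmax(dim=-1)  # masked vision tokens receive ~zero attention
```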
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -116,7 +116,6 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
     Important Note:
         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -71,6 +71,12 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
 
 
 class MidmModel(DecoderOnlyModel):
+    def __init__(self, model, layers, rbln_config, use_learned_pos_emb=None, use_rotary_emb=True):
+        super().__init__(
+            model, layers, rbln_config, use_learned_pos_emb=use_learned_pos_emb, use_rotary_emb=use_rotary_emb
+        )
+        self.use_layernorm1p = getattr(model, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -81,19 +87,22 @@ class MidmModel(DecoderOnlyModel):
         return layernorm1p
 
     def get_last_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_f)
-        else:
-            return self._original_mod.ln_f
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.norm)
+        return self.norm
 
     def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
+        return self.embed_tokens
 
     def get_pos_embedding(self) -> nn.Embedding:
-        return self._original_mod.wpe
+        return self.embed_positions
 
 
 class MidmLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config=None):
+        super().__init__(layer, self_attn, lora_config)
+        self.use_layernorm1p = getattr(layer, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -104,24 +113,22 @@ class MidmLayer(DecoderOnlyLayer):
         return layernorm1p
 
     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_1)
-        else:
-            return self._original_mod.ln_1
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.pre_attention_layernorm)
+        return self.pre_attention_layernorm
 
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_2)
-        else:
-            return self._original_mod.ln_2
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.post_attention_layernorm)
+        return self.post_attention_layernorm
 
 
 class MidmAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.c_attn = self._original_mod.c_attn
-        self.o_proj = self._original_mod.c_proj
-        self.split_size = self._original_mod.split_size
-        self.num_key_value_heads = self._original_mod.num_heads
+    def __post_init__(self, self_attn):
+        self.c_attn = self_attn.c_attn
+        self.o_proj = self_attn.c_proj
+        self.split_size = self_attn.split_size
+        self.num_key_value_heads = self_attn.num_heads
 
     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -130,12 +137,12 @@ class MidmAttention(DecoderOnlyAttention):
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 
-    def get_attn_scale(self):
+    def get_attn_scale(self, self_attn):
         scale = 1.0
-        if self._original_mod.scale_attn_weights:
+        if self_attn.scale_attn_weights:
             scale /= math.sqrt(self.head_dim)
 
-        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+        if self_attn.scale_attn_by_inverse_layer_idx:
             scale /= 1 + self.layer_idx
 
         return scale
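
The Midm refactor keeps routing layer norms through `get_layernorm1p`, whose inner body is elided from these hunks. In Midm (and Megatron-style) models, "layernorm1p" conventionally means layer norm applied with `weight + 1`, so zero-initialized scale weights start out as the identity; a sketch under that assumption:

```python
import torch
import torch.nn.functional as F
from torch import nn


def layernorm1p_sketch(module: nn.LayerNorm, input: torch.Tensor) -> torch.Tensor:
    # Assumed formulation: identical to nn.LayerNorm, except the learned scale is
    # stored as (weight - 1) and shifted back by +1 at call time.
    return F.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
```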
optimum/rbln/transformers/models/mixtral/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_mixtral import RBLNMixtralForCausalLMConfig
+from .modeling_mixtral import RBLNMixtralForCausalLM
optimum/rbln/transformers/models/mixtral/configuration_mixtral.py
@@ -0,0 +1,38 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNMixtralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Mixtral models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNMixtralForCausalLM, RBLNMixtralForCausalLMConfig
+    # Create a configuration object
+    config = RBLNMixtralForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=32768,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNMixtralForCausalLM.from_pretrained(
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/mixtral/mixtral_architecture.py
@@ -0,0 +1,76 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
+
+
+class MixtralWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return MixtralLayer
+
+
+class MixtralLayer(DecoderOnlyLayer):
+    _MLP_ATTR = ("block_sparse_moe",)
+
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.block_sparse_moe = MixtralSparseMoeBlock(self.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.block_sparse_moe
+
+
+class MixtralSparseMoeBlock(nn.Module):
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        # self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.gate = model.gate
+        self.experts = MixtralBlockSparseTop2MLP(model.experts, self.top_k)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states
+
+
+class MixtralBlockSparseTop2MLP(nn.Module):
+    def __init__(self, expert_list, top_k):
+        super().__init__()
+        self.hidden_dim = expert_list[0].hidden_dim
+        self.ffn_dim = expert_list[0].ffn_dim
+        self.top_k = top_k
+
+        self.num_experts = len(expert_list)
+        self.w1 = nn.Linear(self.hidden_dim, self.num_experts * self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.num_experts * self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.num_experts * self.ffn_dim, bias=False)
+        self.w1.weight.data = torch.stack([expert.w1.weight.data for expert in expert_list], dim=0)
+        self.w2.weight.data = torch.stack([expert.w2.weight.data for expert in expert_list], dim=0)
+        self.w3.weight.data = torch.stack([expert.w3.weight.data for expert in expert_list], dim=0)
+
+    def forward(self, x, router_logits):
+        return torch.ops.rbln_custom_ops.custom_moe_glu(
+            x, self.w1.weight, self.w3.weight, self.w2.weight, router_logits, self.top_k, True
+        )
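
As with the GPT-OSS experts above, `custom_moe_glu` fuses routing and the SwiGLU expert MLPs into one kernel over the stacked `w1`/`w3`/`w2` weights. A dense reference of the equivalent Mixtral-style computation follows; reading the trailing `True` argument as "renormalize the top-k routing weights" is an assumption.

```python
import torch
import torch.nn.functional as F


def mixtral_moe_reference(x, w1, w3, w2, router_logits, top_k):
    """Dense sketch: x (tokens, hidden); w1/w3 (experts, ffn, hidden); w2 (experts, hidden, ffn)."""
    routing = router_logits.softmax(dim=-1)
    top_w, top_i = routing.topk(top_k, dim=-1)
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # assumed meaning of the final flag
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        rows, slots = (top_i == e).nonzero(as_tuple=True)  # tokens routed to expert e
        if rows.numel() == 0:
            continue
        h = x[rows]
        expert_out = (F.silu(h @ w1[e].T) * (h @ w3[e].T)) @ w2[e].T  # SwiGLU expert MLP
        out[rows] += top_w[rows, slots, None] * expert_out
    return out
```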
optimum/rbln/transformers/models/mixtral/modeling_mixtral.py
@@ -0,0 +1,68 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .mixtral_architecture import MixtralWrapper
+
+
+class RBLNMixtralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    Mixtral is a Mixture-of-Experts (MoE) variant of Mistral, available as a base model and an aligned chat model.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based MixtralForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers MixtralForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNMixtralForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNMixtralForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNMixtralForCausalLMConfig`] class for all available configuration options.
+    Examples:
+        ```python
+        from optimum.rbln import RBLNMixtralForCausalLM
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=4,
+        )
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "max_seq_len": 32768,
+            "tensor_parallel_size": 4,
+        }
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_config=rbln_config
+        )
+        # Using a RBLNMixtralForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNMixtralForCausalLMConfig
+        config = RBLNMixtralForCausalLMConfig(
+            batch_size=1,
+            max_seq_len=32768,
+            tensor_parallel_size=4
+        )
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = MixtralWrapper