optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -16,7 +16,6 @@ import copy
 from typing import Optional, Tuple, Union

 import torch
-from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm

 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
@@ -95,16 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
         else:
             seq_positions = cache_position[:, :1]

-        sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+        cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+        sliding_cache_pos = (cache_seq_len, cache_offset)

         all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos_local if is_sliding else cos_global,
@@ -120,11 +121,8 @@ class Gemma3TextModel(DecoderOnlyModel):


 class Gemma3DecoderLayer(DecoderOnlyLayer):
-    def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.pre_feedforward_layernorm
-
-    def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.post_feedforward_layernorm
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]

     def forward(
         self,
@@ -164,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):


 class Gemma3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
-
-    def get_attn_scale(self):
-        return self._original_mod.config.query_pre_attn_scalar**-0.5
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
+
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
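Across this release, the `DecoderOnlyAttention` hooks (`__post_init__`, `get_attn_scale`) receive the original Hugging Face attention module as an explicit `self_attn` argument instead of reading it from `self._original_mod`, as the Gemma3 hunk above and the GPT-2/Midm hunks below show. A minimal sketch of the new subclassing pattern; `MyAttention` and the `head_dim` attribute are illustrative assumptions, not part of the package:

```python
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyAttention


class MyAttention(DecoderOnlyAttention):
    # Hypothetical subclass: the hooks now receive the wrapped HF attention module directly.
    def __post_init__(self, self_attn):
        self.q_proj = self_attn.q_proj
        self.k_proj = self_attn.k_proj
        self.v_proj = self_attn.v_proj
        self.o_proj = self_attn.o_proj

    def get_attn_scale(self, self_attn):
        # Assumes the wrapped module exposes `head_dim`; adjust per architecture.
        return self_attn.head_dim**-0.5
```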
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -325,7 +325,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
                 batch_size,
                 inputs_embeds.shape[1],
                 self.config.text_config.hidden_size,
-                dtype=self.rbln_config.torch_dtype,
+                dtype=self.rbln_config.dtype,
             )
             for _ in range(self.config.text_config.num_hidden_layers + 1)
         )
@@ -20,8 +20,6 @@ import torch.nn as nn
20
20
 
21
21
  from ..decoderonly.decoderonly_architecture import (
22
22
  DecoderOnlyAttention,
23
- DecoderOnlyLayer,
24
- DecoderOnlyModel,
25
23
  DecoderOnlyWrapper,
26
24
  )
27
25
 
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
34
32
  def get_rbln_attn_class(self):
35
33
  return GPT2Attention
36
34
 
37
- def get_rbln_layer_class(self):
38
- return GPT2Layer
39
-
40
- def get_rbln_model_class(self):
41
- return GPT2Model
42
-
43
35
  def get_attn_layer(self, layer: nn.Module):
44
36
  return layer.attn
45
37
 
@@ -50,30 +42,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
50
42
  return model.transformer.h if self.is_causal_lm else model.h
51
43
 
52
44
 
53
- class GPT2Model(DecoderOnlyModel):
54
- def get_last_layernorm(self) -> nn.LayerNorm:
55
- return self._original_mod.ln_f
56
-
57
- def get_embedding(self) -> nn.Embedding:
58
- return self._original_mod.wte
59
-
60
- def get_pos_embedding(self) -> nn.Embedding:
61
- return self._original_mod.wpe
62
-
63
-
64
- class GPT2Layer(DecoderOnlyLayer):
65
- def get_pre_attention_layernorm(self) -> nn.LayerNorm:
66
- return self._original_mod.ln_1
67
-
68
- def get_post_attention_layernorm(self) -> nn.LayerNorm:
69
- return self._original_mod.ln_2
70
-
71
-
72
45
  class GPT2Attention(DecoderOnlyAttention):
73
- def __post_init__(self):
74
- self.c_attn = self._original_mod.c_attn
75
- self.o_proj = self._original_mod.c_proj
76
- self.split_size = self._original_mod.split_size
46
+ def __post_init__(self, self_attn):
47
+ self.c_attn = self_attn.c_attn
48
+ self.o_proj = self_attn.c_proj
49
+ self.split_size = self_attn.split_size
50
+ self.num_key_value_heads = self_attn.num_heads
77
51
 
78
52
  def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79
53
  if lora_int_id is not None:
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
82
56
  query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
83
57
  return query_states, key_states, value_states
84
58
 
85
- def get_attn_scale(self):
59
+ def get_attn_scale(self, self_attn):
86
60
  scale = 1.0
87
- if self._original_mod.scale_attn_weights:
61
+ if self_attn.scale_attn_weights:
88
62
  scale /= math.sqrt(self.head_dim)
89
63
 
90
- if self._original_mod.scale_attn_by_inverse_layer_idx:
64
+ if self_attn.scale_attn_by_inverse_layer_idx:
91
65
  scale /= 1 + self.layer_idx
92
66
 
93
67
  return scale
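A small worked example of the combined scaling above, with hypothetical sizes (not taken from any particular checkpoint):

```python
import math

# With scale_attn_weights and scale_attn_by_inverse_layer_idx both enabled,
# head_dim = 64 and layer_idx = 3:
scale = 1.0
scale /= math.sqrt(64)  # -> 0.125
scale /= 1 + 3          # -> 0.03125, i.e. 1 / (sqrt(head_dim) * (layer_idx + 1))
print(scale)
```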
optimum/rbln/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
+from .modeling_gpt_oss import RBLNGptOssForCausalLM
optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py
@@ -0,0 +1,41 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN GPT-OSS models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNGptOssForCausalLMConfig(
+        batch_size=1,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py
@@ -0,0 +1,122 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyWrapper,
+)
+
+
+class RBLNGptOssWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return RBLNGptOssLayer
+
+
+class RBLNGptOssLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = RBLNGptOssMLP(layer.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class RBLNGptOssTopKRouter(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.weight = model.weight
+        self.bias = model.bias
+
+    def forward(self, hidden_states):
+        return F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+
+
+class RBLNGptOssExperts(nn.Module):
+    def __init__(self, model, top_k: Optional[int] = None):
+        super().__init__()
+        self.intermediate_size = model.intermediate_size
+        self.num_experts = model.num_experts
+        self.hidden_size = model.hidden_size
+
+        self.register_buffer(
+            "gate_proj_blocks",
+            model.gate_up_proj_blocks.data[:, ::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("gate_proj_scales", model.gate_up_proj_scales.data[:, ::2, :])
+        self.register_buffer(
+            "gate_proj_bias",
+            model.gate_up_proj_bias.data[:, ::2].reshape(self.num_experts, self.intermediate_size),
+        )
+
+        self.register_buffer(
+            "up_proj_blocks",
+            model.gate_up_proj_blocks.data[:, 1::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("up_proj_scales", model.gate_up_proj_scales.data[:, 1::2, :])
+        self.register_buffer(
+            "up_proj_bias", model.gate_up_proj_bias.data[:, 1::2].reshape(self.num_experts, self.intermediate_size)
+        )
+
+        self.register_buffer(
+            "down_proj_blocks", model.down_proj_blocks.data.reshape(self.num_experts, self.hidden_size, -1)
+        )
+        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
+        self.register_buffer("down_proj_bias", model.down_proj_bias.data)
+
+        self.alpha = model.alpha  # 1.702
+        self.limit = model.limit  # 7.0
+        self.top_k = top_k
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
+        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
+            hidden_states,
+            self.gate_proj_blocks,
+            self.gate_proj_scales,
+            self.gate_proj_bias,
+            self.up_proj_blocks,
+            self.up_proj_scales,
+            self.up_proj_bias,
+            self.down_proj_blocks,
+            self.down_proj_scales,
+            self.down_proj_bias,
+            router_logits,
+            torch.tensor(self.alpha, dtype=hidden_states.dtype),
+            torch.tensor(self.limit, dtype=hidden_states.dtype),
+            k=self.top_k,
+            post_norm=True,
+        )
+
+
+class RBLNGptOssMLP(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.router = RBLNGptOssTopKRouter(model.router)
+        self.experts = RBLNGptOssExperts(model.experts, top_k=model.router.top_k)
+
+    def forward(self, hidden_states):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_logits=router_logits)
+        routed_out = routed_out.reshape(batch_size, sequence_length, hidden_dim)
+        return routed_out
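The heavy lifting above is delegated to the fused `torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4` op, which consumes the MXFP4 block/scale buffers directly, so its kernel is not visible in this diff. As a rough mental model only, the sketch below combines top-k routing (softmax over the selected experts, which is presumably what `post_norm=True` selects) with the clamped GLU commonly described for GPT-OSS experts, using dense weights; the function name, tensor layouts, and the per-token loop are illustrative assumptions, not the op's implementation:

```python
import torch
import torch.nn.functional as F


def reference_moe_glu(hidden, router_logits, gate_w, gate_b, up_w, up_b, down_w, down_b,
                      alpha=1.702, limit=7.0, top_k=4):
    # hidden: (tokens, hidden_size); router_logits: (tokens, num_experts)
    # *_w: (num_experts, out_features, in_features) dense weights; *_b: (num_experts, out_features)
    top_vals, top_idx = router_logits.topk(top_k, dim=-1)
    routing = top_vals.softmax(dim=-1)  # normalize only over the selected experts
    out = torch.zeros_like(hidden)
    for t in range(hidden.shape[0]):
        for w, e in zip(routing[t], top_idx[t]):
            gate = F.linear(hidden[t], gate_w[e], gate_b[e]).clamp(max=limit)
            up = F.linear(hidden[t], up_w[e], up_b[e]).clamp(-limit, limit)
            glu = gate * torch.sigmoid(gate * alpha)          # clamped GLU with alpha = 1.702
            out[t] += w * F.linear((up + 1) * glu, down_w[e], down_b[e])
    return out


# Tiny smoke test with made-up sizes
tokens, hidden_size, inter, experts = 3, 8, 16, 4
h = torch.randn(tokens, hidden_size)
logits = torch.randn(tokens, experts)
gw, gb = torch.randn(experts, inter, hidden_size), torch.randn(experts, inter)
uw, ub = torch.randn(experts, inter, hidden_size), torch.randn(experts, inter)
dw, db = torch.randn(experts, hidden_size, inter), torch.randn(experts, hidden_size)
print(reference_moe_glu(h, logits, gw, gb, uw, ub, dw, db, top_k=2).shape)  # torch.Size([3, 8])
```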
optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -0,0 +1,165 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
+from transformers.modeling_utils import PreTrainedModel, no_init_weights
+
+from ....utils.logging import get_logger
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModelConfig,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from ...utils.rbln_quantization import load_weight_files
+from .gpt_oss_architecture import RBLNGptOssWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+logger = get_logger(__name__)
+
+
+class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based GPT-OSSForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers GPT-OSSForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGptOssForCausalLM
+
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+
+
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=rbln_config
+    )
+
+
+    # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGptOssForCausalLMConfig

+    config = RBLNGptOssForCausalLMConfig(
+        batch_size=1,
+        tensor_parallel_size=4
+    )
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = RBLNGptOssWrapper
+
+    @staticmethod
+    def _get_dtype(dtype: Union[str, torch.dtype] = None, torch_dtype: Union[str, torch.dtype] = None):
+        # For BC on torch_dtype argument
+        if torch_dtype is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # If both kwargs are provided, use `dtype`
+            dtype = dtype if dtype is not None else torch_dtype
+
+        # As mxfp4_quantizer's default dtype
+        if dtype is None or dtype == "auto":
+            dtype = torch.bfloat16
+
+        return dtype
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        *args,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        dtype: Union[str, torch.dtype] = None,
+        torch_dtype: Union[str, torch.dtype] = None,
+        config: Optional[PretrainedConfig] = None,
+        **kwargs,
+    ) -> PreTrainedModel:
+        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
+        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}
+
+        if config is None:
+            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)
+
+        dtype = cls._get_dtype(dtype, torch_dtype)
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)
+
+        _replace_with_mxfp4_linear(model, config)
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+
+        if rbln_config.use_attention_mask:
+            raise ValueError(
+                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
+            )
+
+        return rbln_config
+
+
+def _replace_with_mxfp4_linear(
+    model,
+    config,
+):
+    for name, module in model.named_children():
+        if module.__class__.__name__ == "GptOssExperts":
+            model._modules[name] = Mxfp4GptOssExperts(config)
+        if len(list(module.children())) > 0:
+            _replace_with_mxfp4_linear(module, config)
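A usage note on the `_get_dtype` shim above: `dtype` takes precedence over the deprecated `torch_dtype`, and `None`/`"auto"` resolves to `torch.bfloat16` to match the MXFP4 quantizer default. Illustrative calls (the method is internal; shown here only to document the behavior):

```python
import torch

from optimum.rbln import RBLNGptOssForCausalLM

# The deprecated alias is still honored when `dtype` is not given (emits a deprecation warning).
print(RBLNGptOssForCausalLM._get_dtype(dtype=None, torch_dtype=torch.float16))  # torch.float16
print(RBLNGptOssForCausalLM._get_dtype(dtype="auto"))                           # torch.bfloat16
```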
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
-        super().__init__(**kwargs)
-        self.encoder = encoder
-        self.decoder = decoder
-        self.text_backbone = text_backbone
-        self.backbone = backbone
+
+        super().__init__(batch_size=batch_size, **kwargs)
+        self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
+        self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
+        self.text_backbone = self.initialize_submodule_config(
+            submodule_config=text_backbone, batch_size=self.batch_size
+        )
+        self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
         self.output_attentions = output_attentions if output_attentions is not None else False
         self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False

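The submodule configs now go through `initialize_submodule_config`, so the parent `batch_size` is propagated to the encoder, decoder, and backbone configs instead of the raw arguments being stored verbatim. A hedged usage sketch (assuming the config class is exported at the package root like other RBLN configs, and that unset submodules are filled in with defaults; treat this as illustrative, not documented behavior):

```python
from optimum.rbln import RBLNGroundingDinoForObjectDetectionConfig

# Assumption: submodule configs left unset are initialized by the parent config,
# inheriting its batch_size.
config = RBLNGroundingDinoForObjectDetectionConfig(batch_size=2)
print(config.encoder.batch_size, config.decoder.batch_size)  # expected: 2 2
```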
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):

         # mask vision for language
         if vision_attention_mask is not None:
-            # RBLN FIX: bool tensor to float tensor
-            mask = vision_attention_mask * torch.finfo(torch.float16).min
-            text_attn_weights = text_attn_weights.transpose(1, 2) + mask
-            text_attn_weights = text_attn_weights.transpose(1, 2)
+            # RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
+            mask = vision_attention_mask
+            if mask.dim() == 3:
+                mask = mask[..., 0]
+            mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+            text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min

         text_attn_weights = text_attn_weights.softmax(dim=-1)

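The replacement folds the mask into the attention logits additively in the model's dtype rather than via transposes and a hard-coded float16 constant. A shape-only illustration of that idiom with made-up sizes (none of these names or shapes come from the model itself):

```python
import torch

batch, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_weights = torch.randn(batch * num_heads, tgt_len, src_len)
position_mask = torch.tensor([[0, 0, 1, 1, 1],
                              [0, 1, 1, 1, 1]], dtype=torch.float32)  # 1 = masked out

# Expand across heads, flatten into the (batch * heads) layout of the weights,
# and add as a large negative bias before softmax.
mask = position_mask[:, None, None, :].repeat(1, num_heads, 1, 1).flatten(0, 1)  # (batch*heads, 1, src_len)
attn_weights = attn_weights + mask * torch.finfo(attn_weights.dtype).min
probs = attn_weights.softmax(dim=-1)
print(probs[0].sum(dim=-1))  # rows still sum to 1; masked columns get ~0 probability
```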
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -116,7 +116,6 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
     Important Note:
         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -71,6 +71,12 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):


 class MidmModel(DecoderOnlyModel):
+    def __init__(self, model, layers, rbln_config, use_learned_pos_emb=None, use_rotary_emb=True):
+        super().__init__(
+            model, layers, rbln_config, use_learned_pos_emb=use_learned_pos_emb, use_rotary_emb=use_rotary_emb
+        )
+        self.use_layernorm1p = getattr(model, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -81,19 +87,22 @@ class MidmModel(DecoderOnlyModel):
         return layernorm1p

     def get_last_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_f)
-        else:
-            return self._original_mod.ln_f
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.norm)
+        return self.norm

     def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
+        return self.embed_tokens

     def get_pos_embedding(self) -> nn.Embedding:
-        return self._original_mod.wpe
+        return self.embed_positions


 class MidmLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config=None):
+        super().__init__(layer, self_attn, lora_config)
+        self.use_layernorm1p = getattr(layer, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -104,24 +113,22 @@ class MidmLayer(DecoderOnlyLayer):
         return layernorm1p

     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_1)
-        else:
-            return self._original_mod.ln_1
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.pre_attention_layernorm)
+        return self.pre_attention_layernorm

     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        if self._original_mod.use_layernorm1p:
-            return self.get_layernorm1p(self._original_mod.ln_2)
-        else:
-            return self._original_mod.ln_2
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.post_attention_layernorm)
+        return self.post_attention_layernorm


 class MidmAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.c_attn = self._original_mod.c_attn
-        self.o_proj = self._original_mod.c_proj
-        self.split_size = self._original_mod.split_size
-        self.num_key_value_heads = self._original_mod.num_heads
+    def __post_init__(self, self_attn):
+        self.c_attn = self_attn.c_attn
+        self.o_proj = self_attn.c_proj
+        self.split_size = self_attn.split_size
+        self.num_key_value_heads = self_attn.num_heads

     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -130,12 +137,12 @@ class MidmAttention(DecoderOnlyAttention):
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states

-    def get_attn_scale(self):
+    def get_attn_scale(self, self_attn):
         scale = 1.0
-        if self._original_mod.scale_attn_weights:
+        if self_attn.scale_attn_weights:
             scale /= math.sqrt(self.head_dim)

-        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+        if self_attn.scale_attn_by_inverse_layer_idx:
             scale /= 1 + self.layer_idx

         return scale
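The body of `layernorm1p` itself is unchanged and not shown in these hunks. For readers unfamiliar with the trick: in implementations that use it (e.g. NeMo-style LayerNorm1P), it is ordinary layer normalization with the stored weight shifted by one, so a zero-initialized weight acts as a unit scale. A minimal sketch under that assumption, not taken from the Midm sources:

```python
import torch
import torch.nn.functional as F


def layernorm1p(module: torch.nn.LayerNorm, input: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: standard LayerNorm, but the stored weight is offset by +1.
    return F.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)


ln = torch.nn.LayerNorm(8)
torch.nn.init.zeros_(ln.weight)          # weight stored around 0 ...
x = torch.randn(2, 8)
print(layernorm1p(ln, x).std(dim=-1))    # ... yet the output is still unit-scale normalized
```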