optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,7 +16,6 @@ import copy
|
|
|
16
16
|
from typing import Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
|
-
from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
|
|
20
19
|
|
|
21
20
|
from ..decoderonly.decoderonly_architecture import (
|
|
22
21
|
DecoderOnlyAttention,
|
|
@@ -95,16 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
|
|
|
95
94
|
else:
|
|
96
95
|
seq_positions = cache_position[:, :1]
|
|
97
96
|
|
|
98
|
-
|
|
97
|
+
cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
|
|
98
|
+
sliding_cache_pos = (cache_seq_len, cache_offset)
|
|
99
99
|
|
|
100
100
|
all_hidden_states = () if output_hidden_states else None
|
|
101
101
|
for layer_idx, layer in enumerate(self.layers):
|
|
102
102
|
if output_hidden_states:
|
|
103
103
|
all_hidden_states += (hidden_states,)
|
|
104
104
|
is_sliding = True if layer_idx in self.sliding_window_layers else False
|
|
105
|
+
is_sliding_decode = is_sliding and self.phase == "decode"
|
|
105
106
|
hidden_states = layer(
|
|
106
107
|
hidden_states=hidden_states,
|
|
107
|
-
attention_mask=attention_mask,
|
|
108
|
+
attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
|
|
108
109
|
seq_positions=sliding_cache_pos if is_sliding else seq_positions,
|
|
109
110
|
past_key_values=past_key_values,
|
|
110
111
|
cos=cos_local if is_sliding else cos_global,
|
|
@@ -120,11 +121,8 @@ class Gemma3TextModel(DecoderOnlyModel):
|
|
|
120
121
|
|
|
121
122
|
|
|
122
123
|
class Gemma3DecoderLayer(DecoderOnlyLayer):
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
|
|
127
|
-
return self._original_mod.post_feedforward_layernorm
|
|
124
|
+
_PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
|
|
125
|
+
_POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
|
|
128
126
|
|
|
129
127
|
def forward(
|
|
130
128
|
self,
|
|
@@ -164,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
|
|
|
164
162
|
|
|
165
163
|
|
|
166
164
|
class Gemma3Attention(DecoderOnlyAttention):
|
|
167
|
-
def __post_init__(self):
|
|
168
|
-
self.q_proj =
|
|
169
|
-
self.k_proj =
|
|
170
|
-
self.v_proj =
|
|
171
|
-
self.o_proj =
|
|
172
|
-
self.q_norm =
|
|
173
|
-
self.k_norm =
|
|
174
|
-
|
|
175
|
-
def get_attn_scale(self):
|
|
176
|
-
return
|
|
165
|
+
def __post_init__(self, self_attn):
|
|
166
|
+
self.q_proj = self_attn.q_proj
|
|
167
|
+
self.k_proj = self_attn.k_proj
|
|
168
|
+
self.v_proj = self_attn.v_proj
|
|
169
|
+
self.o_proj = self_attn.o_proj
|
|
170
|
+
self.q_norm = self_attn.q_norm
|
|
171
|
+
self.k_norm = self_attn.k_norm
|
|
172
|
+
|
|
173
|
+
def get_attn_scale(self, self_attn):
|
|
174
|
+
return self_attn.config.query_pre_attn_scalar**-0.5
|
|
@@ -325,7 +325,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
|
|
|
325
325
|
batch_size,
|
|
326
326
|
inputs_embeds.shape[1],
|
|
327
327
|
self.config.text_config.hidden_size,
|
|
328
|
-
dtype=self.rbln_config.
|
|
328
|
+
dtype=self.rbln_config.dtype,
|
|
329
329
|
)
|
|
330
330
|
for _ in range(self.config.text_config.num_hidden_layers + 1)
|
|
331
331
|
)
|
|
@@ -20,8 +20,6 @@ import torch.nn as nn
|
|
|
20
20
|
|
|
21
21
|
from ..decoderonly.decoderonly_architecture import (
|
|
22
22
|
DecoderOnlyAttention,
|
|
23
|
-
DecoderOnlyLayer,
|
|
24
|
-
DecoderOnlyModel,
|
|
25
23
|
DecoderOnlyWrapper,
|
|
26
24
|
)
|
|
27
25
|
|
|
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
|
|
|
34
32
|
def get_rbln_attn_class(self):
|
|
35
33
|
return GPT2Attention
|
|
36
34
|
|
|
37
|
-
def get_rbln_layer_class(self):
|
|
38
|
-
return GPT2Layer
|
|
39
|
-
|
|
40
|
-
def get_rbln_model_class(self):
|
|
41
|
-
return GPT2Model
|
|
42
|
-
|
|
43
35
|
def get_attn_layer(self, layer: nn.Module):
|
|
44
36
|
return layer.attn
|
|
45
37
|
|
|
@@ -50,30 +42,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
|
|
|
50
42
|
return model.transformer.h if self.is_causal_lm else model.h
|
|
51
43
|
|
|
52
44
|
|
|
53
|
-
class GPT2Model(DecoderOnlyModel):
|
|
54
|
-
def get_last_layernorm(self) -> nn.LayerNorm:
|
|
55
|
-
return self._original_mod.ln_f
|
|
56
|
-
|
|
57
|
-
def get_embedding(self) -> nn.Embedding:
|
|
58
|
-
return self._original_mod.wte
|
|
59
|
-
|
|
60
|
-
def get_pos_embedding(self) -> nn.Embedding:
|
|
61
|
-
return self._original_mod.wpe
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class GPT2Layer(DecoderOnlyLayer):
|
|
65
|
-
def get_pre_attention_layernorm(self) -> nn.LayerNorm:
|
|
66
|
-
return self._original_mod.ln_1
|
|
67
|
-
|
|
68
|
-
def get_post_attention_layernorm(self) -> nn.LayerNorm:
|
|
69
|
-
return self._original_mod.ln_2
|
|
70
|
-
|
|
71
|
-
|
|
72
45
|
class GPT2Attention(DecoderOnlyAttention):
|
|
73
|
-
def __post_init__(self):
|
|
74
|
-
self.c_attn =
|
|
75
|
-
self.o_proj =
|
|
76
|
-
self.split_size =
|
|
46
|
+
def __post_init__(self, self_attn):
|
|
47
|
+
self.c_attn = self_attn.c_attn
|
|
48
|
+
self.o_proj = self_attn.c_proj
|
|
49
|
+
self.split_size = self_attn.split_size
|
|
50
|
+
self.num_key_value_heads = self_attn.num_heads
|
|
77
51
|
|
|
78
52
|
def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
79
53
|
if lora_int_id is not None:
|
|
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
|
|
|
82
56
|
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
|
83
57
|
return query_states, key_states, value_states
|
|
84
58
|
|
|
85
|
-
def get_attn_scale(self):
|
|
59
|
+
def get_attn_scale(self, self_attn):
|
|
86
60
|
scale = 1.0
|
|
87
|
-
if
|
|
61
|
+
if self_attn.scale_attn_weights:
|
|
88
62
|
scale /= math.sqrt(self.head_dim)
|
|
89
63
|
|
|
90
|
-
if
|
|
64
|
+
if self_attn.scale_attn_by_inverse_layer_idx:
|
|
91
65
|
scale /= 1 + self.layer_idx
|
|
92
66
|
|
|
93
67
|
return scale
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
|
|
16
|
+
from .modeling_gpt_oss import RBLNGptOssForCausalLM
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    Configuration class for RBLN GPT-OSS models.

    This class subclasses RBLNDecoderOnlyModelForCausalLMConfig without adding any options of
    its own, so every option of the parent configuration applies unchanged.

    Note:
        `use_attention_mask=True` is not supported: `RBLNGptOssForCausalLM` raises a
        `ValueError` while updating the RBLN config, because the custom attention kernel does
        not support attention sinks for masked attention.

    Example usage:
    ```python
    from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig

    # Create a configuration object
    config = RBLNGptOssForCausalLMConfig(
        batch_size=1,
        tensor_parallel_size=4,
    )

    # Use the configuration with from_pretrained
    model = RBLNGptOssForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        export=True,
        rbln_config=config,
    )
    ```
    """
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn.functional as F
|
|
20
|
+
from torch import nn
|
|
21
|
+
|
|
22
|
+
from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
|
|
23
|
+
from ..decoderonly.decoderonly_architecture import (
|
|
24
|
+
DecoderOnlyAttention,
|
|
25
|
+
DecoderOnlyLayer,
|
|
26
|
+
DecoderOnlyWrapper,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RBLNGptOssWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper whose layers are built as :class:`RBLNGptOssLayer`."""

    def get_rbln_layer_class(self):
        """Return the layer class used to wrap each GPT-OSS decoder layer."""
        # RBLNGptOssLayer installs the router + experts MLP backed by the mxfp4 MoE custom op.
        layer_cls = RBLNGptOssLayer
        return layer_cls
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RBLNGptOssLayer(DecoderOnlyLayer):
    """Decoder layer whose feed-forward block is an :class:`RBLNGptOssMLP` (router + experts)."""

    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__(layer, self_attn, lora_config)
        # Install the MoE wrapper built from the original layer's MLP; this assignment happens
        # after the base initializer, so it takes precedence over any `mlp` set there.
        moe_block = RBLNGptOssMLP(layer.mlp)
        self.mlp = moe_block

    def get_mlp(self) -> nn.Module:
        """Return the MoE feed-forward module used by this layer."""
        return self.mlp
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class RBLNGptOssTopKRouter(nn.Module):
    """
    Router projection that produces one logit per expert for each token.

    Despite the name, no top-k selection happens here: the module only computes the linear
    logits. Expert selection is left to the experts op, which receives `k` separately.
    """

    def __init__(self, model):
        super().__init__()
        # Share (not copy) the source router's weight and bias.
        self.weight, self.bias = model.weight, model.bias

    def forward(self, hidden_states):
        """Return router logits of shape (tokens, num_experts) for 2-D `hidden_states`."""
        logits = F.linear(hidden_states, self.weight, self.bias)
        return logits
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RBLNGptOssExperts(nn.Module):
    """
    GPT-OSS expert weights rearranged for the RBLN mxfp4 MoE GLU custom op.

    The source module stores the gate and up projections fused along dim 1. Even rows are
    treated as the gate projection and odd rows as the up projection. They are de-interleaved
    here into separate buffers, and the trailing block dimensions are flattened per row.
    """

    def __init__(self, model, top_k: Optional[int] = None):
        super().__init__()
        self.intermediate_size = model.intermediate_size
        self.num_experts = model.num_experts
        self.hidden_size = model.hidden_size

        n_experts = self.num_experts
        inter_size = self.intermediate_size
        fused_blocks = model.gate_up_proj_blocks.data
        fused_scales = model.gate_up_proj_scales.data
        fused_bias = model.gate_up_proj_bias.data

        # Even rows of the fused projection -> gate projection.
        self.register_buffer("gate_proj_blocks", fused_blocks[:, 0::2, :, :].reshape(n_experts, inter_size, -1))
        self.register_buffer("gate_proj_scales", fused_scales[:, 0::2, :])
        self.register_buffer("gate_proj_bias", fused_bias[:, 0::2].reshape(n_experts, inter_size))

        # Odd rows of the fused projection -> up projection.
        self.register_buffer("up_proj_blocks", fused_blocks[:, 1::2, :, :].reshape(n_experts, inter_size, -1))
        self.register_buffer("up_proj_scales", fused_scales[:, 1::2, :])
        self.register_buffer("up_proj_bias", fused_bias[:, 1::2].reshape(n_experts, inter_size))

        # Down projection: merge trailing block dims into one per (expert, hidden) row.
        self.register_buffer(
            "down_proj_blocks", model.down_proj_blocks.data.reshape(n_experts, self.hidden_size, -1)
        )
        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
        self.register_buffer("down_proj_bias", model.down_proj_bias.data)

        self.alpha = model.alpha  # 1.702
        self.limit = model.limit  # 7.0
        self.top_k = top_k

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        """Run the RBLN mxfp4 MoE GLU custom op over `hidden_states` using `router_logits`."""
        act_dtype = hidden_states.dtype
        # The custom op takes alpha/limit as tensors in the activation dtype.
        alpha_t = torch.tensor(self.alpha, dtype=act_dtype)
        limit_t = torch.tensor(self.limit, dtype=act_dtype)
        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
            hidden_states,
            self.gate_proj_blocks,
            self.gate_proj_scales,
            self.gate_proj_bias,
            self.up_proj_blocks,
            self.up_proj_scales,
            self.up_proj_bias,
            self.down_proj_blocks,
            self.down_proj_scales,
            self.down_proj_bias,
            router_logits,
            alpha_t,
            limit_t,
            k=self.top_k,
            post_norm=True,
        )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class RBLNGptOssMLP(nn.Module):
    """GPT-OSS mixture-of-experts feed-forward block: router logits followed by the expert op."""

    def __init__(self, model):
        super().__init__()
        hf_router = model.router
        self.router = RBLNGptOssTopKRouter(hf_router)
        # The experts op performs the top-k selection, so it needs the router's k.
        self.experts = RBLNGptOssExperts(model.experts, top_k=hf_router.top_k)

    def forward(self, hidden_states):
        """Route a (batch, seq, hidden) input through the experts and restore its shape."""
        bsz, seq_len, hidden_dim = hidden_states.shape
        # Flatten tokens so the router and experts operate on a 2-D (tokens, hidden) tensor.
        tokens = hidden_states.view(-1, hidden_dim)
        logits = self.router(tokens)
        mixed = self.experts(tokens, router_logits=logits)
        return mixed.reshape(bsz, seq_len, hidden_dim)
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from safetensors.torch import load_file
|
|
19
|
+
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
|
|
20
|
+
from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
|
|
21
|
+
from transformers.modeling_utils import PreTrainedModel, no_init_weights
|
|
22
|
+
|
|
23
|
+
from ....utils.logging import get_logger
|
|
24
|
+
from ...models.decoderonly import (
|
|
25
|
+
RBLNDecoderOnlyModelConfig,
|
|
26
|
+
RBLNDecoderOnlyModelForCausalLM,
|
|
27
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
28
|
+
)
|
|
29
|
+
from ...utils.rbln_quantization import load_weight_files
|
|
30
|
+
from .gpt_oss_architecture import RBLNGptOssWrapper
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
35
|
+
|
|
36
|
+
logger = get_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers-based GptOssForCausalLM models on RBLN devices.
    It implements the methods to convert a pre-trained transformers GptOssForCausalLM model into an RBLN transformer model by:
        - transferring the checkpoint weights of the original into an optimized RBLN graph,
        - compiling the resulting graph using the RBLN compiler.

    The expert weights are kept in their mxfp4 checkpoint form: the PyTorch model is built with
    `Mxfp4GptOssExperts` modules and the safetensors checkpoint is loaded into it directly.
    `use_attention_mask=True` is not supported for this model.

    **Configuration:**
    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.

    Examples:
        ```python
        from optimum.rbln import RBLNGptOssForCausalLM

        # Simple usage using rbln_* arguments
        # `max_seq_len` is automatically inferred from the model config
        model = RBLNGptOssForCausalLM.from_pretrained(
            "openai/gpt-oss-20b",
            export=True,
            rbln_batch_size=1,
            rbln_tensor_parallel_size=4,
        )


        # Using a config dictionary
        rbln_config = {
            "batch_size": 1,
            "tensor_parallel_size": 4,
        }
        model = RBLNGptOssForCausalLM.from_pretrained(
            "openai/gpt-oss-20b",
            export=True,
            rbln_config=rbln_config,
        )


        # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
        from optimum.rbln import RBLNGptOssForCausalLMConfig

        config = RBLNGptOssForCausalLMConfig(
            batch_size=1,
            tensor_parallel_size=4,
        )
        model = RBLNGptOssForCausalLM.from_pretrained(
            "openai/gpt-oss-20b",
            export=True,
            rbln_config=config,
        )
        ```
    """

    _decoder_wrapper_cls = RBLNGptOssWrapper

    @staticmethod
    def _get_dtype(dtype: Union[str, torch.dtype] = None, torch_dtype: Union[str, torch.dtype] = None):
        """
        Resolve the dtype used to build the PyTorch model.

        Args:
            dtype: Preferred dtype argument.
            torch_dtype: Deprecated alias of `dtype`; only used when `dtype` is None.

        Returns:
            `dtype` (or `torch_dtype`), falling back to `torch.bfloat16` when neither is given
            or the value is `"auto"`.
        """
        # For BC on torch_dtype argument
        if torch_dtype is not None:
            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
            # If both kwargs are provided, use `dtype`
            dtype = dtype if dtype is not None else torch_dtype

        # As mxfp4_quantizer's default dtype
        if dtype is None or dtype == "auto":
            dtype = torch.bfloat16

        return dtype

    @classmethod
    def get_pytorch_model(
        cls,
        model_id: str,
        *args,
        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
        dtype: Union[str, torch.dtype] = None,
        torch_dtype: Union[str, torch.dtype] = None,
        config: Optional[PretrainedConfig] = None,
        **kwargs,
    ) -> PreTrainedModel:
        """
        Build the GPT-OSS PyTorch model with mxfp4 experts and load the checkpoint into it.

        Args:
            model_id: Model id or path passed to `load_weight_files` / `AutoConfig`.
            rbln_config: Accepted for interface compatibility; not used here.
            dtype: Model dtype (see `_get_dtype`).
            torch_dtype: Deprecated alias of `dtype`.
            config: Optional pre-loaded config; loaded from `model_id` when None.
            **kwargs: Forwarded to `AutoModelForCausalLM.from_config` when `config` is given.

        Returns:
            The PyTorch model with checkpoint weights loaded.
        """
        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}

        if config is None:
            # NOTE(review): the caller's **kwargs are not forwarded to AutoConfig here and are
            # replaced by its (empty) unused kwargs, so they never reach `from_config` — confirm intended.
            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)

        dtype = cls._get_dtype(dtype, torch_dtype)

        # Skip random initialization: all weights are expected to come from the checkpoint.
        with no_init_weights():
            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)

        _replace_with_mxfp4_linear(model, config)
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            # With `no_init_weights`, parameters the checkpoint does not provide may be left
            # uninitialized; report them instead of failing silently under `strict=False`.
            logger.warning(
                "%d key(s) of the GPT-OSS model were not found in the checkpoint and may be uninitialized: %s",
                len(load_result.missing_keys),
                ", ".join(load_result.missing_keys[:10]),
            )

        return model

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
        """
        Update the RBLN config via the parent class, then reject unsupported options.

        Raises:
            ValueError: If `rbln_config.use_attention_mask` is enabled.
        """
        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

        if rbln_config.use_attention_mask:
            raise ValueError(
                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
            )

        return rbln_config
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _replace_with_mxfp4_linear(
    model,
    config,
):
    """
    Recursively replace every `GptOssExperts` submodule of `model` with `Mxfp4GptOssExperts`.

    Despite the function name, only expert modules are swapped; they are matched by class name.
    The replacements are freshly constructed from `config`, so their weights are expected to be
    loaded afterwards (e.g. via `load_state_dict`).

    Args:
        model: Module tree modified in place.
        config: Model config used to construct each `Mxfp4GptOssExperts`.
    """
    # Snapshot the children so replacing an entry does not mutate the mapping being iterated.
    for name, module in list(model.named_children()):
        if module.__class__.__name__ == "GptOssExperts":
            setattr(model, name, Mxfp4GptOssExperts(config))
            # The old experts module has been detached; there is nothing left to visit in it.
            continue
        _replace_with_mxfp4_linear(module, config)
|
|
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
|
|
|
50
50
|
Raises:
|
|
51
51
|
ValueError: If batch_size is not a positive integer.
|
|
52
52
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
self.
|
|
56
|
-
self.
|
|
57
|
-
self.
|
|
53
|
+
|
|
54
|
+
super().__init__(batch_size=batch_size, **kwargs)
|
|
55
|
+
self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
|
|
56
|
+
self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
|
|
57
|
+
self.text_backbone = self.initialize_submodule_config(
|
|
58
|
+
submodule_config=text_backbone, batch_size=self.batch_size
|
|
59
|
+
)
|
|
60
|
+
self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
|
|
58
61
|
self.output_attentions = output_attentions if output_attentions is not None else False
|
|
59
62
|
self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
|
60
63
|
|
|
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
|
|
|
509
509
|
|
|
510
510
|
# mask vision for language
|
|
511
511
|
if vision_attention_mask is not None:
|
|
512
|
-
# RBLN FIX: bool tensor to float tensor
|
|
513
|
-
mask = vision_attention_mask
|
|
514
|
-
|
|
515
|
-
|
|
512
|
+
# RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
|
|
513
|
+
mask = vision_attention_mask
|
|
514
|
+
if mask.dim() == 3:
|
|
515
|
+
mask = mask[..., 0]
|
|
516
|
+
mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
|
517
|
+
text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
|
|
516
518
|
|
|
517
519
|
text_attn_weights = text_attn_weights.softmax(dim=-1)
|
|
518
520
|
|
|
@@ -116,7 +116,6 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
|
|
|
116
116
|
RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
|
|
117
117
|
optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
|
|
118
118
|
This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
119
|
-
|
|
120
119
|
Important Note:
|
|
121
120
|
This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
|
|
122
121
|
tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
|
|
@@ -71,6 +71,12 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
|
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
class MidmModel(DecoderOnlyModel):
|
|
74
|
+
def __init__(self, model, layers, rbln_config, use_learned_pos_emb=None, use_rotary_emb=True):
|
|
75
|
+
super().__init__(
|
|
76
|
+
model, layers, rbln_config, use_learned_pos_emb=use_learned_pos_emb, use_rotary_emb=use_rotary_emb
|
|
77
|
+
)
|
|
78
|
+
self.use_layernorm1p = getattr(model, "use_layernorm1p", False)
|
|
79
|
+
|
|
74
80
|
def get_layernorm1p(self, module: nn.LayerNorm):
|
|
75
81
|
def layernorm1p(input: torch.Tensor):
|
|
76
82
|
"""Applies Layer Normalization with a slight modification on the weights."""
|
|
@@ -81,19 +87,22 @@ class MidmModel(DecoderOnlyModel):
|
|
|
81
87
|
return layernorm1p
|
|
82
88
|
|
|
83
89
|
def get_last_layernorm(self) -> nn.LayerNorm:
|
|
84
|
-
if self.
|
|
85
|
-
return self.get_layernorm1p(self.
|
|
86
|
-
|
|
87
|
-
return self._original_mod.ln_f
|
|
90
|
+
if self.use_layernorm1p:
|
|
91
|
+
return self.get_layernorm1p(self.norm)
|
|
92
|
+
return self.norm
|
|
88
93
|
|
|
89
94
|
def get_embedding(self) -> nn.Embedding:
|
|
90
|
-
return self.
|
|
95
|
+
return self.embed_tokens
|
|
91
96
|
|
|
92
97
|
def get_pos_embedding(self) -> nn.Embedding:
|
|
93
|
-
return self.
|
|
98
|
+
return self.embed_positions
|
|
94
99
|
|
|
95
100
|
|
|
96
101
|
class MidmLayer(DecoderOnlyLayer):
|
|
102
|
+
def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config=None):
|
|
103
|
+
super().__init__(layer, self_attn, lora_config)
|
|
104
|
+
self.use_layernorm1p = getattr(layer, "use_layernorm1p", False)
|
|
105
|
+
|
|
97
106
|
def get_layernorm1p(self, module: nn.LayerNorm):
|
|
98
107
|
def layernorm1p(input: torch.Tensor):
|
|
99
108
|
"""Applies Layer Normalization with a slight modification on the weights."""
|
|
@@ -104,24 +113,22 @@ class MidmLayer(DecoderOnlyLayer):
|
|
|
104
113
|
return layernorm1p
|
|
105
114
|
|
|
106
115
|
def get_pre_attention_layernorm(self) -> nn.LayerNorm:
|
|
107
|
-
if self.
|
|
108
|
-
return self.get_layernorm1p(self.
|
|
109
|
-
|
|
110
|
-
return self._original_mod.ln_1
|
|
116
|
+
if self.use_layernorm1p:
|
|
117
|
+
return self.get_layernorm1p(self.pre_attention_layernorm)
|
|
118
|
+
return self.pre_attention_layernorm
|
|
111
119
|
|
|
112
120
|
def get_post_attention_layernorm(self) -> nn.LayerNorm:
|
|
113
|
-
if self.
|
|
114
|
-
return self.get_layernorm1p(self.
|
|
115
|
-
|
|
116
|
-
return self._original_mod.ln_2
|
|
121
|
+
if self.use_layernorm1p:
|
|
122
|
+
return self.get_layernorm1p(self.post_attention_layernorm)
|
|
123
|
+
return self.post_attention_layernorm
|
|
117
124
|
|
|
118
125
|
|
|
119
126
|
class MidmAttention(DecoderOnlyAttention):
|
|
120
|
-
def __post_init__(self):
|
|
121
|
-
self.c_attn =
|
|
122
|
-
self.o_proj =
|
|
123
|
-
self.split_size =
|
|
124
|
-
self.num_key_value_heads =
|
|
127
|
+
def __post_init__(self, self_attn):
|
|
128
|
+
self.c_attn = self_attn.c_attn
|
|
129
|
+
self.o_proj = self_attn.c_proj
|
|
130
|
+
self.split_size = self_attn.split_size
|
|
131
|
+
self.num_key_value_heads = self_attn.num_heads
|
|
125
132
|
|
|
126
133
|
def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
127
134
|
if lora_int_id is not None:
|
|
@@ -130,12 +137,12 @@ class MidmAttention(DecoderOnlyAttention):
|
|
|
130
137
|
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
|
131
138
|
return query_states, key_states, value_states
|
|
132
139
|
|
|
133
|
-
def get_attn_scale(self):
|
|
140
|
+
def get_attn_scale(self, self_attn):
|
|
134
141
|
scale = 1.0
|
|
135
|
-
if
|
|
142
|
+
if self_attn.scale_attn_weights:
|
|
136
143
|
scale /= math.sqrt(self.head_dim)
|
|
137
144
|
|
|
138
|
-
if
|
|
145
|
+
if self_attn.scale_attn_by_inverse_layer_idx:
|
|
139
146
|
scale /= 1 + self.layer_idx
|
|
140
147
|
|
|
141
148
|
return scale
|