optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
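
The `+44` lines in `optimum/rbln/__init__.py` and `optimum/rbln/transformers/__init__.py` indicate that the newly added model families are re-exported from the package root. A minimal import check, limited to names the hunks below actually confirm (`RBLNGptOssForCausalLM`, `RBLNMixtralForCausalLM`, and `RBLNMixtralForCausalLMConfig` all appear in `from optimum.rbln import ...` examples in the new docstrings):

```python
# Hypothetical smoke test: confirm the classes added in 0.10.0 are importable
# from the package root. Only these names are confirmed by the docstrings in
# the diff below; other new classes may follow the same pattern but are not
# verified here.
from optimum.rbln import (
    RBLNGptOssForCausalLM,       # from modeling_gpt_oss.py docstring
    RBLNMixtralForCausalLM,      # from modeling_mixtral.py docstring
    RBLNMixtralForCausalLMConfig,
)

print(RBLNGptOssForCausalLM.__name__, RBLNMixtralForCausalLM.__name__)
```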
optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py (new file)
@@ -0,0 +1,122 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyWrapper,
+)
+
+
+class RBLNGptOssWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return RBLNGptOssLayer
+
+
+class RBLNGptOssLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = RBLNGptOssMLP(layer.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class RBLNGptOssTopKRouter(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.weight = model.weight
+        self.bias = model.bias
+
+    def forward(self, hidden_states):
+        return F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+
+
+class RBLNGptOssExperts(nn.Module):
+    def __init__(self, model, top_k: Optional[int] = None):
+        super().__init__()
+        self.intermediate_size = model.intermediate_size
+        self.num_experts = model.num_experts
+        self.hidden_size = model.hidden_size
+
+        self.register_buffer(
+            "gate_proj_blocks",
+            model.gate_up_proj_blocks.data[:, ::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("gate_proj_scales", model.gate_up_proj_scales.data[:, ::2, :])
+        self.register_buffer(
+            "gate_proj_bias",
+            model.gate_up_proj_bias.data[:, ::2].reshape(self.num_experts, self.intermediate_size),
+        )
+
+        self.register_buffer(
+            "up_proj_blocks",
+            model.gate_up_proj_blocks.data[:, 1::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("up_proj_scales", model.gate_up_proj_scales.data[:, 1::2, :])
+        self.register_buffer(
+            "up_proj_bias", model.gate_up_proj_bias.data[:, 1::2].reshape(self.num_experts, self.intermediate_size)
+        )
+
+        self.register_buffer(
+            "down_proj_blocks", model.down_proj_blocks.data.reshape(self.num_experts, self.hidden_size, -1)
+        )
+        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
+        self.register_buffer("down_proj_bias", model.down_proj_bias.data)
+
+        self.alpha = model.alpha  # 1.702
+        self.limit = model.limit  # 7.0
+        self.top_k = top_k
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
+        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
+            hidden_states,
+            self.gate_proj_blocks,
+            self.gate_proj_scales,
+            self.gate_proj_bias,
+            self.up_proj_blocks,
+            self.up_proj_scales,
+            self.up_proj_bias,
+            self.down_proj_blocks,
+            self.down_proj_scales,
+            self.down_proj_bias,
+            router_logits,
+            torch.tensor(self.alpha, dtype=hidden_states.dtype),
+            torch.tensor(self.limit, dtype=hidden_states.dtype),
+            k=self.top_k,
+            post_norm=True,
+        )
+
+
+class RBLNGptOssMLP(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.router = RBLNGptOssTopKRouter(model.router)
+        self.experts = RBLNGptOssExperts(model.experts, top_k=model.router.top_k)
+
+    def forward(self, hidden_states):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_logits=router_logits)
+        routed_out = routed_out.reshape(batch_size, sequence_length, hidden_dim)
+        return routed_out
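
In `RBLNGptOssExperts` above, the fused `gate_up_proj_*` tensors are split into separate gate and up projections by even/odd slicing along the second dimension before being handed to the fused `custom_moe_glu_mxfp4` op. A toy sketch of that de-interleaving, assuming the layout interleaves gate and up rows as `[g0, u0, g1, u1, ...]`:

```python
import torch

num_experts, intermediate_size = 2, 3
# Toy stand-in for gate_up_proj_bias: rows interleave gate and up entries.
gate_up_bias = torch.arange(num_experts * 2 * intermediate_size, dtype=torch.float32)
gate_up_bias = gate_up_bias.reshape(num_experts, 2 * intermediate_size)

gate_bias = gate_up_bias[:, ::2].reshape(num_experts, intermediate_size)   # even positions
up_bias = gate_up_bias[:, 1::2].reshape(num_experts, intermediate_size)    # odd positions

print(gate_bias)  # tensor([[ 0.,  2.,  4.], [ 6.,  8., 10.]])
print(up_bias)    # tensor([[ 1.,  3.,  5.], [ 7.,  9., 11.]])
```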
optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py (new file)
@@ -0,0 +1,168 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
+from transformers.modeling_utils import PreTrainedModel, no_init_weights
+
+from ....utils.logging import get_logger
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModelConfig,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from ...utils.rbln_quantization import load_weight_files
+from .gpt_oss_architecture import RBLNGptOssWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+logger = get_logger(__name__)
+
+
+class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based GptOssForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers GptOssForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNGptOssForCausalLM
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=8,
+            rbln_kvcache_partition_len=8192,
+        )
+
+
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "tensor_parallel_size": 8,
+            "kvcache_partition_len": 8192,
+        }
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_config=rbln_config,
+        )
+
+
+        # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNGptOssForCausalLMConfig
+
+        config = RBLNGptOssForCausalLMConfig(
+            batch_size=1,
+            tensor_parallel_size=8,
+            kvcache_partition_len=8192,
+        )
+        model = RBLNGptOssForCausalLM.from_pretrained(
+            "openai/gpt-oss-20b",
+            export=True,
+            rbln_config=config,
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = RBLNGptOssWrapper
+
+    @staticmethod
+    def _get_dtype(dtype: Union[str, torch.dtype] = None, torch_dtype: Union[str, torch.dtype] = None):
+        # For BC on torch_dtype argument
+        if torch_dtype is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # If both kwargs are provided, use `dtype`
+            dtype = dtype if dtype is not None else torch_dtype
+
+        # As mxfp4_quantizer's default dtype
+        if dtype is None or dtype == "auto":
+            dtype = torch.bfloat16
+
+        return dtype
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        *args,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        dtype: Union[str, torch.dtype] = None,
+        torch_dtype: Union[str, torch.dtype] = None,
+        config: Optional[PretrainedConfig] = None,
+        **kwargs,
+    ) -> PreTrainedModel:
+        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
+        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}
+
+        if config is None:
+            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)
+
+        dtype = cls._get_dtype(dtype, torch_dtype)
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)
+
+        _replace_with_mxfp4_linear(model, config)
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+
+        if rbln_config.use_attention_mask:
+            raise ValueError(
+                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
+            )
+
+        return rbln_config
+
+
+def _replace_with_mxfp4_linear(
+    model,
+    config,
+):
+    for name, module in model.named_children():
+        if module.__class__.__name__ == "GptOssExperts":
+            model._modules[name] = Mxfp4GptOssExperts(config)
+        if len(list(module.children())) > 0:
+            _replace_with_mxfp4_linear(module, config)
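
`get_pytorch_model` above rebuilds the checkpoint by loading all non-`original` safetensors shards, instantiating the model under `no_init_weights()`, swapping every `GptOssExperts` for `Mxfp4GptOssExperts`, and then loading the state dict with `strict=False`. The module-replacement recursion can be illustrated with plain `torch.nn` stand-ins (the classes below are hypothetical, not the transformers ones):

```python
import torch.nn as nn

class ToyExperts(nn.Module):          # stand-in for GptOssExperts
    pass

class ToyQuantExperts(nn.Module):     # stand-in for Mxfp4GptOssExperts
    pass

def replace_experts(model: nn.Module) -> None:
    # Mirror of the recursion in _replace_with_mxfp4_linear: swap matching
    # direct children in place, then recurse into submodules that have children.
    for name, module in model.named_children():
        if isinstance(module, ToyExperts):
            model._modules[name] = ToyQuantExperts()
        if len(list(module.children())) > 0:
            replace_experts(module)

block = nn.Sequential(nn.Linear(4, 4), ToyExperts())
model = nn.Sequential(block, ToyExperts())
replace_experts(model)
print(model)  # both ToyExperts instances are now ToyQuantExperts
```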
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
-
-
-        self.
-        self.
-        self.
+
+        super().__init__(batch_size=batch_size, **kwargs)
+        self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
+        self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
+        self.text_backbone = self.initialize_submodule_config(
+            submodule_config=text_backbone, batch_size=self.batch_size
+        )
+        self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
         self.output_attentions = output_attentions if output_attentions is not None else False
         self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False
 
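
The rewritten `RBLNGroundingDinoForObjectDetectionConfig.__init__` now routes every submodule config through `initialize_submodule_config`, so the parent `batch_size` reaches the encoder, decoder, backbone, and text backbone uniformly. A toy sketch of that propagation pattern (hypothetical dataclasses, not the RBLN config API; whether an explicitly set submodule batch size is preserved is an assumption here):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SubConfig:
    batch_size: Optional[int] = None

def initialize_submodule(sub: Optional[SubConfig], batch_size: int) -> SubConfig:
    # Hypothetical analogue of initialize_submodule_config: create the
    # submodule config if missing and inherit the parent batch size.
    sub = sub or SubConfig()
    if sub.batch_size is None:
        sub.batch_size = batch_size
    return sub

parent_batch_size = 2
encoder = initialize_submodule(None, parent_batch_size)
decoder = initialize_submodule(SubConfig(batch_size=1), parent_batch_size)
print(encoder.batch_size, decoder.batch_size)  # 2 1
```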
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
 
         # mask vision for language
         if vision_attention_mask is not None:
-            # RBLN FIX: bool tensor to float tensor
-            mask = vision_attention_mask
-
-
+            # RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
+            mask = vision_attention_mask
+            if mask.dim() == 3:
+                mask = mask[..., 0]
+            mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+            text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
 
         text_attn_weights = text_attn_weights.softmax(dim=-1)
 
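
The fix replaces the old masking code with an additive formulation: the vision mask is expanded from `(batch, src_len)` to `(batch * num_heads, 1, src_len)` and scaled by the dtype minimum so that masked positions are pushed toward `-inf` before the softmax. A shape-only sketch (the convention that `1.0` marks positions to suppress is an assumption; the diff only shows the additive form):

```python
import torch

batch, num_heads, tgt_len, src_len = 2, 4, 5, 7
text_attn_weights = torch.zeros(batch * num_heads, tgt_len, src_len)

# Toy vision mask: 1.0 at positions assumed to be suppressed.
mask = torch.zeros(batch, src_len)
mask[:, -2:] = 1.0

expanded = mask[:, None, None, :].repeat(1, num_heads, 1, 1).flatten(0, 1)
print(expanded.shape)  # torch.Size([8, 1, 7]); broadcasts over tgt_len

masked = text_attn_weights + expanded * torch.finfo(text_attn_weights.dtype).min
print(masked[0, 0])  # last two source positions pushed toward -inf
```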
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -116,7 +116,6 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
     Important Note:
     This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
     tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -71,6 +71,12 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
 
 
 class MidmModel(DecoderOnlyModel):
+    def __init__(self, model, layers, rbln_config, use_learned_pos_emb=None, use_rotary_emb=True):
+        super().__init__(
+            model, layers, rbln_config, use_learned_pos_emb=use_learned_pos_emb, use_rotary_emb=use_rotary_emb
+        )
+        self.use_layernorm1p = getattr(model, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -81,19 +87,22 @@ class MidmModel(DecoderOnlyModel):
         return layernorm1p
 
     def get_last_layernorm(self) -> nn.LayerNorm:
-        if self.
-            return self.get_layernorm1p(self.
-
-        return self._original_mod.ln_f
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.norm)
+        return self.norm
 
     def get_embedding(self) -> nn.Embedding:
-        return self.
+        return self.embed_tokens
 
     def get_pos_embedding(self) -> nn.Embedding:
-        return self.
+        return self.embed_positions
 
 
 class MidmLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config=None):
+        super().__init__(layer, self_attn, lora_config)
+        self.use_layernorm1p = getattr(layer, "use_layernorm1p", False)
+
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -104,24 +113,22 @@ class MidmLayer(DecoderOnlyLayer):
         return layernorm1p
 
     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        if self.
-            return self.get_layernorm1p(self.
-
-        return self._original_mod.ln_1
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.pre_attention_layernorm)
+        return self.pre_attention_layernorm
 
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        if self.
-            return self.get_layernorm1p(self.
-
-        return self._original_mod.ln_2
+        if self.use_layernorm1p:
+            return self.get_layernorm1p(self.post_attention_layernorm)
+        return self.post_attention_layernorm
 
 
 class MidmAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.c_attn =
-        self.o_proj =
-        self.split_size =
-        self.num_key_value_heads =
+    def __post_init__(self, self_attn):
+        self.c_attn = self_attn.c_attn
+        self.o_proj = self_attn.c_proj
+        self.split_size = self_attn.split_size
+        self.num_key_value_heads = self_attn.num_heads
 
     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -130,12 +137,12 @@ class MidmAttention(DecoderOnlyAttention):
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 
-    def get_attn_scale(self):
+    def get_attn_scale(self, self_attn):
         scale = 1.0
-        if
+        if self_attn.scale_attn_weights:
             scale /= math.sqrt(self.head_dim)
 
-        if
+        if self_attn.scale_attn_by_inverse_layer_idx:
             scale /= 1 + self.layer_idx
 
         return scale
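
`MidmAttention.get_attn_scale` now reads `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` from the wrapped attention module passed in as `self_attn` rather than from `self._original_mod`. The arithmetic itself is unchanged: `1/sqrt(head_dim)`, further divided by `1 + layer_idx` when inverse-layer-index scaling is enabled. A quick numeric check of that formula:

```python
import math

def attn_scale(head_dim: int, layer_idx: int,
               scale_attn_weights: bool = True,
               scale_attn_by_inverse_layer_idx: bool = True) -> float:
    # Same arithmetic as MidmAttention.get_attn_scale in the diff above.
    scale = 1.0
    if scale_attn_weights:
        scale /= math.sqrt(head_dim)
    if scale_attn_by_inverse_layer_idx:
        scale /= 1 + layer_idx
    return scale

print(attn_scale(head_dim=64, layer_idx=0))  # 0.125
print(attn_scale(head_dim=64, layer_idx=3))  # 0.03125
```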
optimum/rbln/transformers/models/mixtral/__init__.py (new file)
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_mixtral import RBLNMixtralForCausalLMConfig
+from .modeling_mixtral import RBLNMixtralForCausalLM
optimum/rbln/transformers/models/mixtral/configuration_mixtral.py (new file)
@@ -0,0 +1,38 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNMixtralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Mixtral models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNMixtralForCausalLM, RBLNMixtralForCausalLMConfig
+    # Create a configuration object
+    config = RBLNMixtralForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=32768,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNMixtralForCausalLM.from_pretrained(
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/mixtral/mixtral_architecture.py (new file)
@@ -0,0 +1,76 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
+
+
+class MixtralWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return MixtralLayer
+
+
+class MixtralLayer(DecoderOnlyLayer):
+    _MLP_ATTR = ("block_sparse_moe",)
+
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.block_sparse_moe = MixtralSparseMoeBlock(self.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.block_sparse_moe
+
+
+class MixtralSparseMoeBlock(nn.Module):
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        # self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.gate = model.gate
+        self.experts = MixtralBlockSparseTop2MLP(model.experts, self.top_k)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states
+
+
+class MixtralBlockSparseTop2MLP(nn.Module):
+    def __init__(self, expert_list, top_k):
+        super().__init__()
+        self.hidden_dim = expert_list[0].hidden_dim
+        self.ffn_dim = expert_list[0].ffn_dim
+        self.top_k = top_k
+
+        self.num_experts = len(expert_list)
+        self.w1 = nn.Linear(self.hidden_dim, self.num_experts * self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.num_experts * self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.num_experts * self.ffn_dim, bias=False)
+        self.w1.weight.data = torch.stack([expert.w1.weight.data for expert in expert_list], dim=0)
+        self.w2.weight.data = torch.stack([expert.w2.weight.data for expert in expert_list], dim=0)
+        self.w3.weight.data = torch.stack([expert.w3.weight.data for expert in expert_list], dim=0)
+
+    def forward(self, x, router_logits):
+        return torch.ops.rbln_custom_ops.custom_moe_glu(
+            x, self.w1.weight, self.w3.weight, self.w2.weight, router_logits, self.top_k, True
+        )
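
`MixtralBlockSparseTop2MLP` stacks the per-expert `w1`/`w3` weights into `(num_experts, ffn_dim, hidden_dim)` tensors (and `w2` into `(num_experts, hidden_dim, ffn_dim)`) and delegates routing to the fused `custom_moe_glu` op. A rough eager-mode reference of what a top-k GLU MoE of this shape presumably computes (softmax over router logits, top-k selection with renormalization, SiLU-gated expert MLPs, weighted sum); this is an assumption about the custom op's semantics, not its implementation:

```python
import torch
import torch.nn.functional as F

def eager_top_k_glu_moe(x, w1, w3, w2, router_logits, top_k):
    # x: (tokens, hidden); w1/w3: (experts, ffn, hidden); w2: (experts, hidden, ffn)
    probs = F.softmax(router_logits, dim=-1)                     # (tokens, experts)
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)
    topk_probs = topk_probs / topk_probs.sum(-1, keepdim=True)   # renormalize over selected experts
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for p, e in zip(topk_probs[t], topk_idx[t]):
            h = F.silu(x[t] @ w1[e].T) * (x[t] @ w3[e].T)        # GLU: silu(w1 x) * (w3 x)
            out[t] += p * (h @ w2[e].T)
    return out

tokens, hidden, ffn, experts = 3, 8, 16, 4
x = torch.randn(tokens, hidden)
w1, w3 = torch.randn(experts, ffn, hidden), torch.randn(experts, ffn, hidden)
w2 = torch.randn(experts, hidden, ffn)
logits = torch.randn(tokens, experts)
print(eager_top_k_glu_moe(x, w1, w3, w2, logits, top_k=2).shape)  # torch.Size([3, 8])
```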
optimum/rbln/transformers/models/mixtral/modeling_mixtral.py (new file)
@@ -0,0 +1,68 @@
+# Copyright 2026 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .mixtral_architecture import MixtralWrapper
+
+
+class RBLNMixtralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    Mixtral is a Mixture-of-Experts (MoE) variant of Mistral, available as a base model and an aligned chat model.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based MixtralForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers MixtralForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNMixtralForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNMixtralForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNMixtralForCausalLMConfig`] class for all available configuration options.
+    Examples:
+        ```python
+        from optimum.rbln import RBLNMixtralForCausalLM
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=4,
+        )
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "max_seq_len": 32768,
+            "tensor_parallel_size": 4,
+        }
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_config=rbln_config
+        )
+        # Using a RBLNMixtralForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNMixtralForCausalLMConfig
+        config = RBLNMixtralForCausalLMConfig(
+            batch_size=1,
+            max_seq_len=32768,
+            tensor_parallel_size=4
+        )
+        model = RBLNMixtralForCausalLM.from_pretrained(
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = MixtralWrapper