optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    RBLN compilation/runtime configuration for Gemma2 causal language models.

    This subclass declares no options of its own; every setting (batch size,
    maximum sequence length, tensor parallelism, ...) is inherited unchanged
    from RBLNDecoderOnlyModelForCausalLMConfig.

    Example usage:
    ```python
    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig

    # Build the configuration object
    config = RBLNGemma2ForCausalLMConfig(
        batch_size=1,
        max_seq_len=8192,
        tensor_parallel_size=4
    )

    # Hand it to from_pretrained
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-9b",
        export=True,
        rbln_config=config
    )
    ```
    """
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
    """
    RBLN compilation/runtime configuration for the bare Gemma2 model
    (no language-modeling head).

    No options are added here; everything is inherited unchanged from
    RBLNDecoderOnlyModelConfig.
    """
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
|
|
20
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Gemma2Wrapper(DecoderOnlyWrapper):
    """
    DecoderOnlyWrapper specialization for Gemma2.

    The three factory hooks below tell the generic wrapper which
    Gemma2-specific classes (defined in this module) to use for the decoder
    layers, the attention modules and the top-level model.
    """

    def get_rbln_layer_class(self):
        """Return the Gemma2 decoder-layer class."""
        layer_cls = Gemma2DecoderLayer
        return layer_cls

    def get_rbln_attn_class(self):
        """Return the Gemma2 attention class."""
        attn_cls = Gemma2Attention
        return attn_cls

    def get_rbln_model_class(self):
        """Return the Gemma2 decoder-stack model class."""
        model_cls = Gemma2Model
        return model_cls
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Gemma2DecoderLayer(DecoderOnlyLayer):
    """
    Gemma2 decoder layer with "sandwich" normalization.

    Both sub-blocks (self-attention and MLP) normalize their input and normalize
    their output again before it is added back to the residual stream.

    The class attributes name the attributes on the original layer module that
    hold the feed-forward norms. They are presumably consumed by the base class's
    ``get_pre_feedforward_layernorm`` / ``get_post_feedforward_layernorm``
    accessors — TODO confirm against DecoderOnlyLayer.
    """

    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
        past_key_values: Tuple[Tuple[torch.Tensor]],
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        lora_int_id: Optional[torch.Tensor] = None,
    ):
        """
        Run one Gemma2 decoder layer.

        Args:
            hidden_states: Input activations of the layer.
            attention_mask: Mask forwarded unchanged to ``self.self_attn``.
            seq_positions: Cache position(s) forwarded unchanged to ``self.self_attn``.
            past_key_values: KV cache forwarded unchanged to ``self.self_attn``.
            cos, sin: Rotary-embedding tables forwarded to ``self.self_attn``.
            block_tables: Paged-cache block tables forwarded to ``self.self_attn``.
            lora_int_id: LoRA adapter id(s), forwarded to both the attention and
                ``self.forward_mlp``.

        Returns:
            The layer output: input + attention branch + MLP branch, each branch
            being pre-normed and post-normed.

        NOTE(review): HF Gemma2 alternates sliding-window and global attention
        across layers. Nothing in this layer distinguishes the two. Confirm that
        this is handled by the base classes or the rbln_config.
        """
        # Self-attention branch: pre-norm -> attention -> post-norm -> residual add.
        attn_in = self.get_pre_attention_layernorm()(hidden_states)
        attn_out = self.self_attn(
            hidden_states=attn_in,
            attention_mask=attention_mask,
            seq_positions=seq_positions,
            past_key_values=past_key_values,
            cos=cos,
            sin=sin,
            block_tables=block_tables,
            lora_int_id=lora_int_id,
        )
        attn_out = self.get_post_attention_layernorm()(attn_out)
        after_attn = hidden_states + attn_out

        # Feed-forward branch: pre-FF norm -> MLP -> post-FF norm -> residual add.
        ff_in = self.get_pre_feedforward_layernorm()(after_attn)
        ff_out = self.forward_mlp(ff_in, lora_int_id)
        ff_out = self.get_post_feedforward_layernorm()(ff_out)
        return after_attn + ff_out
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Gemma2Attention(DecoderOnlyAttention):
    """
    Attention module for Gemma2. Only the score-scaling factor is overridden.

    NOTE(review): the HF Gemma2 config also defines ``attn_logit_softcapping``.
    Nothing in this class applies it. Confirm that it is handled elsewhere or
    intentionally omitted.
    """

    def get_attn_scale(self, self_attn):
        """
        Return the attention score scale for Gemma2.

        Args:
            self_attn: The original attention module. Its ``config.query_pre_attn_scalar``
                is read.

        Returns:
            ``query_pre_attn_scalar ** -0.5``.
        """
        pre_attn_scalar = self_attn.config.query_pre_attn_scalar
        return pre_attn_scalar**-0.5
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class Gemma2Model(DecoderOnlyModel):
    """Gemma2 decoder stack. Only the hidden-state multiplier is overridden."""

    @property
    def hidden_multiplier(self):
        """
        Return ``sqrt(config.hidden_size)``.

        Presumably the base DecoderOnlyModel multiplies the input embeddings by
        this factor, as Gemma-family models scale embeddings — TODO confirm.
        """
        hidden_size = self.config.hidden_size
        return hidden_size**0.5
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from ....utils import logging
|
|
17
|
+
from ...models.decoderonly import (
|
|
18
|
+
RBLNDecoderOnlyModel,
|
|
19
|
+
RBLNDecoderOnlyModelForCausalLM,
|
|
20
|
+
)
|
|
21
|
+
from .gemma2_architecture import Gemma2Wrapper
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Module-level logger; it is not referenced elsewhere in the visible part of this module.
logger = logging.get_logger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based Gemma2ForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into an RBLN transformer model by:

    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
    Examples:
    ```python
    from optimum.rbln import RBLNGemma2ForCausalLM
    # Simple usage using rbln_* arguments
    # `max_seq_len` is automatically inferred from the model config
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-9b",
        export=True,
        rbln_batch_size=1,
        rbln_tensor_parallel_size=4,
    )
    # Using a config dictionary
    rbln_config = {
        "batch_size": 1,
        "max_seq_len": 8192,
        "tensor_parallel_size": 4,
    }
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-9b",
        export=True,
        rbln_config=rbln_config
    )
    # Using an RBLNGemma2ForCausalLMConfig instance (recommended for type checking)
    from optimum.rbln import RBLNGemma2ForCausalLMConfig
    config = RBLNGemma2ForCausalLMConfig(
        batch_size=1,
        max_seq_len=8192,
        tensor_parallel_size=4
    )
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-9b",
        export=True,
        rbln_config=config
    )
    ```
    """

    # Gemma2Wrapper supplies the Gemma2-specific layer, attention and model classes.
    # It is presumably instantiated by the base class when wrapping the HF model
    # for compilation — TODO confirm.
    _decoder_wrapper_cls = Gemma2Wrapper
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class RBLNGemma2Model(RBLNDecoderOnlyModel):
    """
    Bare Gemma2 transformer (no language-modeling head) for RBLN devices.
    Generic functionality is inherited from [`RBLNDecoderOnlyModel`]; see that class for details.

    Converts a pre-trained transformers Gemma2Model into an RBLN model by:

    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    Configured through [`RBLNGemma2ModelConfig`]. For `from_pretrained` or `from_model`,
    pass `rbln_config` as an [`RBLNGemma2ModelConfig`] instance or as a dictionary with the same structure.

    All available options are documented on [`RBLNGemma2ModelConfig`].
    """

    # Wrapper that plugs the Gemma2-specific layer/attention/model classes into the decoder.
    _decoder_wrapper_cls = Gemma2Wrapper
|
|
@@ -58,13 +58,8 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
|
58
58
|
)
|
|
59
59
|
self.image_prefill_chunk_size = image_prefill_chunk_size
|
|
60
60
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
return self.image_prefill_chunk_size is not None
|
|
64
|
-
|
|
65
|
-
@property
|
|
66
|
-
def decoder_runtime_idx(self):
|
|
67
|
-
return 2 if self.use_image_prefill else 1
|
|
61
|
+
if not (self.use_attention_mask and self.use_position_ids):
|
|
62
|
+
raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
|
|
68
63
|
|
|
69
64
|
|
|
70
65
|
class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
|
|
@@ -16,7 +16,6 @@ import copy
|
|
|
16
16
|
from typing import Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
|
-
from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
|
|
20
19
|
|
|
21
20
|
from ..decoderonly.decoderonly_architecture import (
|
|
22
21
|
DecoderOnlyAttention,
|
|
@@ -95,16 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
|
|
|
95
94
|
else:
|
|
96
95
|
seq_positions = cache_position[:, :1]
|
|
97
96
|
|
|
98
|
-
|
|
97
|
+
cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
|
|
98
|
+
sliding_cache_pos = (cache_seq_len, cache_offset)
|
|
99
99
|
|
|
100
100
|
all_hidden_states = () if output_hidden_states else None
|
|
101
101
|
for layer_idx, layer in enumerate(self.layers):
|
|
102
102
|
if output_hidden_states:
|
|
103
103
|
all_hidden_states += (hidden_states,)
|
|
104
104
|
is_sliding = True if layer_idx in self.sliding_window_layers else False
|
|
105
|
+
is_sliding_decode = is_sliding and self.phase == "decode"
|
|
105
106
|
hidden_states = layer(
|
|
106
107
|
hidden_states=hidden_states,
|
|
107
|
-
attention_mask=attention_mask,
|
|
108
|
+
attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
|
|
108
109
|
seq_positions=sliding_cache_pos if is_sliding else seq_positions,
|
|
109
110
|
past_key_values=past_key_values,
|
|
110
111
|
cos=cos_local if is_sliding else cos_global,
|
|
@@ -120,11 +121,8 @@ class Gemma3TextModel(DecoderOnlyModel):
|
|
|
120
121
|
|
|
121
122
|
|
|
122
123
|
class Gemma3DecoderLayer(DecoderOnlyLayer):
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
|
|
127
|
-
return self._original_mod.post_feedforward_layernorm
|
|
124
|
+
_PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
|
|
125
|
+
_POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
|
|
128
126
|
|
|
129
127
|
def forward(
|
|
130
128
|
self,
|
|
@@ -164,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
|
|
|
164
162
|
|
|
165
163
|
|
|
166
164
|
class Gemma3Attention(DecoderOnlyAttention):
|
|
167
|
-
def __post_init__(self):
|
|
168
|
-
self.q_proj =
|
|
169
|
-
self.k_proj =
|
|
170
|
-
self.v_proj =
|
|
171
|
-
self.o_proj =
|
|
172
|
-
self.q_norm =
|
|
173
|
-
self.k_norm =
|
|
174
|
-
|
|
175
|
-
def get_attn_scale(self):
|
|
176
|
-
return
|
|
165
|
+
def __post_init__(self, self_attn):
|
|
166
|
+
self.q_proj = self_attn.q_proj
|
|
167
|
+
self.k_proj = self_attn.k_proj
|
|
168
|
+
self.v_proj = self_attn.v_proj
|
|
169
|
+
self.o_proj = self_attn.o_proj
|
|
170
|
+
self.q_norm = self_attn.q_norm
|
|
171
|
+
self.k_norm = self_attn.k_norm
|
|
172
|
+
|
|
173
|
+
def get_attn_scale(self, self_attn):
|
|
174
|
+
return self_attn.config.query_pre_attn_scalar**-0.5
|
|
@@ -13,11 +13,9 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
import importlib
|
|
15
15
|
import inspect
|
|
16
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict,
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
|
17
17
|
|
|
18
|
-
import rebel
|
|
19
18
|
import torch
|
|
20
|
-
from rebel.compile_context import CompileContext
|
|
21
19
|
from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
|
|
22
20
|
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
|
23
21
|
from transformers.modeling_utils import no_init_weights
|
|
@@ -29,10 +27,7 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
|
|
|
29
27
|
from ...utils.rbln_runtime_wrapper import LoopProcessor
|
|
30
28
|
from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
|
|
31
29
|
from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
|
|
32
|
-
from ..decoderonly.modeling_decoderonly import
|
|
33
|
-
RBLNDecoderOnlyModelForCausalLM,
|
|
34
|
-
)
|
|
35
|
-
from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
|
|
30
|
+
from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
36
31
|
from .gemma3_architecture import Gemma3ForCausalLMWrapper
|
|
37
32
|
from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
|
|
38
33
|
|
|
@@ -325,7 +320,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
|
|
|
325
320
|
batch_size,
|
|
326
321
|
inputs_embeds.shape[1],
|
|
327
322
|
self.config.text_config.hidden_size,
|
|
328
|
-
dtype=self.rbln_config.
|
|
323
|
+
dtype=self.rbln_config.dtype,
|
|
329
324
|
)
|
|
330
325
|
for _ in range(self.config.text_config.num_hidden_layers + 1)
|
|
331
326
|
)
|
|
@@ -455,174 +450,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
455
450
|
f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
|
|
456
451
|
)
|
|
457
452
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
@classmethod
|
|
461
|
-
def _update_rbln_config(
|
|
462
|
-
cls,
|
|
463
|
-
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
|
|
464
|
-
model: Optional["PreTrainedModel"] = None,
|
|
465
|
-
model_config: Optional["PretrainedConfig"] = None,
|
|
466
|
-
rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
|
|
467
|
-
) -> RBLNGemma3ForCausalLMConfig:
|
|
468
|
-
# Update rbln_config with super class
|
|
469
|
-
rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
|
|
470
|
-
|
|
471
|
-
if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
|
|
472
|
-
raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
|
|
473
|
-
|
|
474
|
-
if rbln_config.use_image_prefill:
|
|
475
|
-
if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
|
|
476
|
-
raise NotImplementedError(
|
|
477
|
-
"Not implemented for different prefill chunk sizes between text and image prefill."
|
|
478
|
-
)
|
|
479
|
-
|
|
480
|
-
# Update image prefill compile config
|
|
481
|
-
img_prefill_input_info = cls.get_input_info(
|
|
482
|
-
batch_size=1,
|
|
483
|
-
query_length=rbln_config.image_prefill_chunk_size,
|
|
484
|
-
rbln_config=rbln_config,
|
|
485
|
-
model_config=model_config,
|
|
486
|
-
)
|
|
487
|
-
image_prefill_compile_config = RBLNCompileConfig(
|
|
488
|
-
compiled_model_name="image_prefill", input_info=img_prefill_input_info
|
|
489
|
-
)
|
|
490
|
-
# Insert image_prefill compile config at index 1
|
|
491
|
-
compile_cfgs = rbln_config.compile_cfgs
|
|
492
|
-
compile_cfgs.insert(1, image_prefill_compile_config)
|
|
493
|
-
rbln_config.set_compile_cfgs(compile_cfgs)
|
|
453
|
+
if "image_prefill" not in rbln_config.phases:
|
|
454
|
+
rbln_config.phases = ["prefill", "image_prefill", "decode"]
|
|
494
455
|
|
|
495
456
|
return rbln_config
|
|
496
|
-
|
|
497
|
-
@classmethod
|
|
498
|
-
@torch.inference_mode()
|
|
499
|
-
def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
|
|
500
|
-
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
501
|
-
|
|
502
|
-
rbln_compile_configs = rbln_config.compile_cfgs
|
|
503
|
-
prefill_compile_config = rbln_compile_configs[0]
|
|
504
|
-
|
|
505
|
-
context = CompileContext(use_weight_sharing=True)
|
|
506
|
-
|
|
507
|
-
# Here we use meta tensor, for the memory efficiency.
|
|
508
|
-
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
509
|
-
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
510
|
-
|
|
511
|
-
# Mark static tensors (self kv states)
|
|
512
|
-
static_tensors = {}
|
|
513
|
-
for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
|
|
514
|
-
if "past_key_values" in name:
|
|
515
|
-
static_tensors[name] = tensor
|
|
516
|
-
context.mark_static_address(tensor)
|
|
517
|
-
|
|
518
|
-
def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
|
|
519
|
-
try:
|
|
520
|
-
if quantization:
|
|
521
|
-
quantization.maybe_set_quantization_env()
|
|
522
|
-
original_linear = torch.nn.functional.linear
|
|
523
|
-
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
524
|
-
compiled_model = cls.compile(
|
|
525
|
-
wrapped_model,
|
|
526
|
-
compile_config,
|
|
527
|
-
create_runtimes=rbln_config.create_runtimes,
|
|
528
|
-
device=rbln_config.device,
|
|
529
|
-
example_inputs=example_inputs,
|
|
530
|
-
compile_context=compile_context,
|
|
531
|
-
)
|
|
532
|
-
return compiled_model
|
|
533
|
-
finally:
|
|
534
|
-
torch.nn.functional.linear = original_linear
|
|
535
|
-
if quantization:
|
|
536
|
-
quantization.maybe_reset_quantization_env()
|
|
537
|
-
|
|
538
|
-
wrapped_model.phase = "prefill"
|
|
539
|
-
compiled_prefill = compile_model(
|
|
540
|
-
wrapped_model,
|
|
541
|
-
prefill_compile_config,
|
|
542
|
-
prefill_example_inputs,
|
|
543
|
-
context,
|
|
544
|
-
rbln_config.quantization,
|
|
545
|
-
)
|
|
546
|
-
compiled_models = {"prefill": compiled_prefill}
|
|
547
|
-
|
|
548
|
-
if rbln_config.use_image_prefill:
|
|
549
|
-
image_prefill_compile_config = rbln_compile_configs[1]
|
|
550
|
-
image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
|
|
551
|
-
fill=0, static_tensors=static_tensors
|
|
552
|
-
)
|
|
553
|
-
wrapped_model.phase = "image_prefill"
|
|
554
|
-
compiled_image_prefill = compile_model(
|
|
555
|
-
wrapped_model,
|
|
556
|
-
image_prefill_compile_config,
|
|
557
|
-
image_prefill_example_inputs,
|
|
558
|
-
context,
|
|
559
|
-
rbln_config.quantization,
|
|
560
|
-
)
|
|
561
|
-
compiled_models["image_prefill"] = compiled_image_prefill
|
|
562
|
-
|
|
563
|
-
wrapped_model.phase = "decode"
|
|
564
|
-
for batch_size, dec_compile_config in zip(
|
|
565
|
-
rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
|
|
566
|
-
):
|
|
567
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
568
|
-
compiled_decoder = compile_model(
|
|
569
|
-
wrapped_model,
|
|
570
|
-
dec_compile_config,
|
|
571
|
-
dec_example_inputs,
|
|
572
|
-
context,
|
|
573
|
-
rbln_config.quantization,
|
|
574
|
-
)
|
|
575
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
576
|
-
|
|
577
|
-
return compiled_models
|
|
578
|
-
|
|
579
|
-
@classmethod
|
|
580
|
-
def _create_runtimes(
|
|
581
|
-
cls,
|
|
582
|
-
compiled_models: List[rebel.RBLNCompiledModel],
|
|
583
|
-
rbln_config: RBLNGemma3ForCausalLMConfig,
|
|
584
|
-
) -> List[rebel.Runtime]:
|
|
585
|
-
expected_model_names = [
|
|
586
|
-
"prefill",
|
|
587
|
-
*[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
|
|
588
|
-
]
|
|
589
|
-
if rbln_config.use_image_prefill:
|
|
590
|
-
expected_model_names.insert(1, "image_prefill")
|
|
591
|
-
|
|
592
|
-
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
593
|
-
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
594
|
-
|
|
595
|
-
ret_val = [
|
|
596
|
-
rebel.Runtime(
|
|
597
|
-
compiled_models[0],
|
|
598
|
-
tensor_type="pt",
|
|
599
|
-
device=rbln_config.device_map["prefill"],
|
|
600
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
601
|
-
timeout=rbln_config.timeout,
|
|
602
|
-
)
|
|
603
|
-
]
|
|
604
|
-
if rbln_config.use_image_prefill:
|
|
605
|
-
ret_val.append(
|
|
606
|
-
rebel.Runtime(
|
|
607
|
-
compiled_models[1],
|
|
608
|
-
tensor_type="pt",
|
|
609
|
-
device=rbln_config.device_map["image_prefill"],
|
|
610
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
611
|
-
timeout=rbln_config.timeout,
|
|
612
|
-
),
|
|
613
|
-
)
|
|
614
|
-
|
|
615
|
-
ret_val.extend(
|
|
616
|
-
[
|
|
617
|
-
rebel.Runtime(
|
|
618
|
-
compiled_models[i + rbln_config.decoder_runtime_idx],
|
|
619
|
-
tensor_type="pt",
|
|
620
|
-
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
621
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
622
|
-
timeout=rbln_config.timeout,
|
|
623
|
-
)
|
|
624
|
-
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
625
|
-
]
|
|
626
|
-
)
|
|
627
|
-
|
|
628
|
-
return ret_val
|
|
@@ -20,8 +20,6 @@ import torch.nn as nn
|
|
|
20
20
|
|
|
21
21
|
from ..decoderonly.decoderonly_architecture import (
|
|
22
22
|
DecoderOnlyAttention,
|
|
23
|
-
DecoderOnlyLayer,
|
|
24
|
-
DecoderOnlyModel,
|
|
25
23
|
DecoderOnlyWrapper,
|
|
26
24
|
)
|
|
27
25
|
|
|
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
|
|
|
34
32
|
def get_rbln_attn_class(self):
|
|
35
33
|
return GPT2Attention
|
|
36
34
|
|
|
37
|
-
def get_rbln_layer_class(self):
|
|
38
|
-
return GPT2Layer
|
|
39
|
-
|
|
40
|
-
def get_rbln_model_class(self):
|
|
41
|
-
return GPT2Model
|
|
42
|
-
|
|
43
35
|
def get_attn_layer(self, layer: nn.Module):
|
|
44
36
|
return layer.attn
|
|
45
37
|
|
|
@@ -50,30 +42,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
|
|
|
50
42
|
return model.transformer.h if self.is_causal_lm else model.h
|
|
51
43
|
|
|
52
44
|
|
|
53
|
-
class GPT2Model(DecoderOnlyModel):
|
|
54
|
-
def get_last_layernorm(self) -> nn.LayerNorm:
|
|
55
|
-
return self._original_mod.ln_f
|
|
56
|
-
|
|
57
|
-
def get_embedding(self) -> nn.Embedding:
|
|
58
|
-
return self._original_mod.wte
|
|
59
|
-
|
|
60
|
-
def get_pos_embedding(self) -> nn.Embedding:
|
|
61
|
-
return self._original_mod.wpe
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class GPT2Layer(DecoderOnlyLayer):
|
|
65
|
-
def get_pre_attention_layernorm(self) -> nn.LayerNorm:
|
|
66
|
-
return self._original_mod.ln_1
|
|
67
|
-
|
|
68
|
-
def get_post_attention_layernorm(self) -> nn.LayerNorm:
|
|
69
|
-
return self._original_mod.ln_2
|
|
70
|
-
|
|
71
|
-
|
|
72
45
|
class GPT2Attention(DecoderOnlyAttention):
|
|
73
|
-
def __post_init__(self):
|
|
74
|
-
self.c_attn =
|
|
75
|
-
self.o_proj =
|
|
76
|
-
self.split_size =
|
|
46
|
+
def __post_init__(self, self_attn):
|
|
47
|
+
self.c_attn = self_attn.c_attn
|
|
48
|
+
self.o_proj = self_attn.c_proj
|
|
49
|
+
self.split_size = self_attn.split_size
|
|
50
|
+
self.num_key_value_heads = self_attn.num_heads
|
|
77
51
|
|
|
78
52
|
def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
79
53
|
if lora_int_id is not None:
|
|
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
|
|
|
82
56
|
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
|
83
57
|
return query_states, key_states, value_states
|
|
84
58
|
|
|
85
|
-
def get_attn_scale(self):
|
|
59
|
+
def get_attn_scale(self, self_attn):
|
|
86
60
|
scale = 1.0
|
|
87
|
-
if
|
|
61
|
+
if self_attn.scale_attn_weights:
|
|
88
62
|
scale /= math.sqrt(self.head_dim)
|
|
89
63
|
|
|
90
|
-
if
|
|
64
|
+
if self_attn.scale_attn_by_inverse_layer_idx:
|
|
91
65
|
scale /= 1 + self.layer_idx
|
|
92
66
|
|
|
93
67
|
return scale
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
|
|
16
|
+
from .modeling_gpt_oss import RBLNGptOssForCausalLM
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
19
|
+
"""
|
|
20
|
+
Configuration class for RBLN GptOss models.
|
|
21
|
+
|
|
22
|
+
This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
|
|
23
|
+
|
|
24
|
+
Example usage:
|
|
25
|
+
```python
|
|
26
|
+
from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
|
|
27
|
+
|
|
28
|
+
# Create a configuration object
|
|
29
|
+
config = RBLNGptOssForCausalLMConfig(
|
|
30
|
+
batch_size=1,
|
|
31
|
+
tensor_parallel_size=8,
|
|
32
|
+
kvcache_partition_len=8192,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Use the configuration with from_pretrained
|
|
36
|
+
model = RBLNGptOssForCausalLM.from_pretrained(
|
|
37
|
+
"openai/gpt-oss-20b",
|
|
38
|
+
export=True,
|
|
39
|
+
rbln_config=config,
|
|
40
|
+
)
|
|
41
|
+
```
|
|
42
|
+
"""
|