optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/exaone/exaone_architecture.py (+0 -36)

@@ -18,9 +18,6 @@ import torch.nn as nn
 
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )
 
@@ -42,36 +39,3 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
 
     def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
         return causal_lm.transformer
-
-    def get_rbln_attn_class(self):
-        return ExaoneAttention
-
-    def get_rbln_layer_class(self):
-        return ExaoneLayer
-
-    def get_rbln_model_class(self):
-        return ExaoneModel
-
-
-class ExaoneModel(DecoderOnlyModel):
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-
-class ExaoneLayer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
-class ExaoneAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
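The Exaone-specific shims (`ExaoneModel`, `ExaoneLayer`, `ExaoneAttention`) are dropped and only `get_model_layer` remains on the wrapper; the shared decoder-only base classes now appear to locate the model-specific submodules themselves (the Gemma2/Gemma3 hunks below switch to `_PRE_FF_LAYERNORM_ATTRS`-style attribute lists and a `__post_init__(self, self_attn)` signature). A minimal, hypothetical sketch of that attribute-lookup pattern; the class and attribute names here are illustrative, not the library's actual base-class API:

```python
# Hypothetical sketch, not optimum-rbln's real base class: shows how an
# attribute-name list lets one generic wrapper serve models whose layernorms
# are named differently (e.g. Exaone's ln_1/ln_2 vs. Llama-style names),
# which is what makes the removed per-model subclasses unnecessary.
import torch.nn as nn


class GenericLayerWrapper(nn.Module):
    # Candidate attribute names probed on the wrapped Hugging Face layer.
    _PRE_ATTN_LAYERNORM_ATTRS = ["input_layernorm", "ln_1"]
    _POST_ATTN_LAYERNORM_ATTRS = ["post_attention_layernorm", "ln_2"]

    def __init__(self, original_layer: nn.Module):
        super().__init__()
        self._original_mod = original_layer

    def _resolve(self, candidates: list) -> nn.Module:
        for name in candidates:
            if hasattr(self._original_mod, name):
                return getattr(self._original_mod, name)
        raise AttributeError(f"None of {candidates} found on {type(self._original_mod).__name__}")

    def get_pre_attention_layernorm(self) -> nn.Module:
        return self._resolve(self._PRE_ATTN_LAYERNORM_ATTRS)

    def get_post_attention_layernorm(self) -> nn.Module:
        return self._resolve(self._POST_ATTN_LAYERNORM_ATTRS)
```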
optimum/rbln/transformers/models/gemma2/__init__.py (new file, +16)

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma2 import RBLNGemma2ForCausalLMConfig, RBLNGemma2ModelConfig
+from .modeling_gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2Model
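With this `__init__.py` (and the matching re-exports added in `optimum/rbln/__init__.py`, +48 above), the new classes should be importable from either level, which is also what the docstrings below assume:

```python
# Either import path should resolve once the wheel is installed; the top-level
# re-export is the one used in the docstrings that follow.
from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
from optimum.rbln.transformers.models.gemma2 import RBLNGemma2Model, RBLNGemma2ModelConfig
```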
optimum/rbln/transformers/models/gemma2/configuration_gemma2.py (new file, +45)

@@ -0,0 +1,45 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
+    # Create a configuration object
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
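`RBLNGemma2ModelConfig` ships without a usage example of its own. A parallel sketch for the headless model, under the assumption that `RBLNGemma2Model.from_pretrained` accepts `rbln_config` the same way its causal-LM counterpart does (as the docstrings in `modeling_gemma2.py` below state):

```python
# Hedged sketch mirroring the causal-LM example above for the headless model.
from optimum.rbln import RBLNGemma2Model, RBLNGemma2ModelConfig

config = RBLNGemma2ModelConfig(
    batch_size=1,
    max_seq_len=8192,
    tensor_parallel_size=4,
)
model = RBLNGemma2Model.from_pretrained(
    "google/gemma-2-9b",
    export=True,
    rbln_config=config,
)
```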
optimum/rbln/transformers/models/gemma2/gemma2_architecture.py (new file, +83)

@@ -0,0 +1,83 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class Gemma2Wrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Gemma2DecoderLayer
+
+    def get_rbln_attn_class(self):
+        return Gemma2Attention
+
+    def get_rbln_model_class(self):
+        return Gemma2Model
+
+
+class Gemma2DecoderLayer(DecoderOnlyLayer):
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
+        )
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+        hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma2Attention(DecoderOnlyAttention):
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
+
+
+class Gemma2Model(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self.config.hidden_size**0.5
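The wrapper encodes Gemma2's "sandwich" normalization (a post-norm after both the attention and the MLP blocks, on top of the usual pre-norms), scales attention by `query_pre_attn_scalar**-0.5`, and scales embeddings by `hidden_size**0.5` via `hidden_multiplier`. A simplified restatement of the residual/normalization ordering, with plain placeholder modules standing in for the RBLN attention and MLP ops:

```python
# Simplified illustration only: attn_fn, mlp_fn and the LayerNorms are
# placeholders (Gemma2 itself uses RMSNorm and the compiled RBLN ops).
import torch
import torch.nn as nn


class SandwichNormBlock(nn.Module):
    def __init__(self, hidden_size: int, attn_fn: nn.Module, mlp_fn: nn.Module):
        super().__init__()
        self.pre_attn_norm = nn.LayerNorm(hidden_size)
        self.post_attn_norm = nn.LayerNorm(hidden_size)
        self.pre_ff_norm = nn.LayerNorm(hidden_size)
        self.post_ff_norm = nn.LayerNorm(hidden_size)
        self.attn_fn = attn_fn
        self.mlp_fn = mlp_fn

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block: pre-norm -> attention -> post-norm -> residual add
        residual = hidden_states
        hidden_states = self.post_attn_norm(self.attn_fn(self.pre_attn_norm(hidden_states)))
        hidden_states = residual + hidden_states
        # Feed-forward sub-block: pre-norm -> MLP -> post-norm -> residual add
        residual = hidden_states
        hidden_states = self.post_ff_norm(self.mlp_fn(self.pre_ff_norm(hidden_states)))
        return residual + hidden_states
```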
optimum/rbln/transformers/models/gemma2/modeling_gemma2.py (new file, +101)

@@ -0,0 +1,101 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+)
+from .gemma2_architecture import Gemma2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNMistralForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGemma2ForCausalLMConfig
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
+
+
+class RBLNGemma2Model(RBLNDecoderOnlyModel):
+    """
+    The Gemma2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2Model model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2Model model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
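The docstring covers compilation only. An end-to-end sketch, assuming the usual `transformers` tokenizer API and that the compiled model exposes `generate()` like the other RBLN decoder-only causal LMs (cf. `generation_decoderonly.py` in the file list above):

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNGemma2ForCausalLM

# Compile the checkpoint for RBLN devices, then run standard greedy generation.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = RBLNGemma2ForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    export=True,
    rbln_batch_size=1,
    rbln_tensor_parallel_size=4,
)

inputs = tokenizer("The RBLN compiler turns this checkpoint into", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```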
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py (+23 -19)

@@ -16,7 +16,6 @@ import copy
 from typing import Optional, Tuple, Union
 
 import torch
-from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
@@ -64,6 +63,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -94,13 +94,18 @@ class Gemma3TextModel(DecoderOnlyModel):
         else:
             seq_positions = cache_position[:, :1]
 
-
+        cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+        sliding_cache_pos = (cache_seq_len, cache_offset)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos_local if is_sliding else cos_global,
@@ -110,15 +115,14 @@ class Gemma3TextModel(DecoderOnlyModel):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
-
-
-
-    def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
-        return self._original_mod.post_feedforward_layernorm
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
 
     def forward(
         self,
@@ -158,13 +162,13 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
 
 
 class Gemma3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.o_proj =
-        self.q_norm =
-        self.k_norm =
-
-    def get_attn_scale(self):
-        return
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm
+
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
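`Gemma3TextModel.forward` now threads `output_hidden_states` through the layer loop, recording the running hidden state before each layer and once more after the final norm, and selects the sliding-window mask and rotary values for sliding layers during decode. A minimal illustration of the collection pattern (HF-style tuple accumulation; `layers` and `final_norm` are generic placeholders, not the compiled submodules):

```python
import torch
import torch.nn as nn


def run_layers(
    hidden_states: torch.Tensor,
    layers: nn.ModuleList,
    final_norm: nn.Module,
    output_hidden_states: bool = False,
):
    # Record the input to every layer, then the post-norm output, mirroring
    # the Hugging Face output_hidden_states convention.
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        hidden_states = layer(hidden_states)
    hidden_states = final_norm(hidden_states)
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    return hidden_states, all_hidden_states
```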
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py (+42 -70)

@@ -16,7 +16,7 @@ from typing import Optional
 import rebel
 import torch
 
-from ...modeling_outputs import
+from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
 from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         super().__init__(*args, **kwargs)
         self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
         self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
-        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -106,6 +105,8 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         while step < query_length:
             if self.rbln_config.use_image_prefill:
                 # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
 
             if is_image_prefill:
-
+                outputs = self.image_prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -157,7 +158,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
-
+                outputs = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
 
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
             padded_cache_lengths += current_padded_cache_lengths
             step += num_processed_tokens
 
-        if
-
-
-
-
-
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-        lora_int_ids: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        if self.rbln_config.use_lora and lora_int_ids is None:
-            if self.lora_int_ids is None:
-                raise ValueError(
-                    "lora_int_id is required when using LoRA. "
-                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
-                )
-
-            lora_int_ids = self.lora_int_ids
-
-        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
-            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
 
-
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+            all_hidden_states = concatenated_hidden_states
 
-        #
-
-
-
-        if local_block_tables is None:
-            raise ValueError("local_block_tables should be provided with external block tables.")
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
         else:
-
-
-
-
-
-
-
-
-
-
-                f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-            )
-            self.dec_attn_mask[b_idx, decoding_step] = 1
-
-        attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
 
-
-            attention_mask = attention_mask[: self.batch_size]
+                output_logits = new_output_logits
 
-
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
 
-        return
+        return RBLNGemma3ForCausalLMOutput(
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
+        )
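The Gemma3-specific `decode_forward` override is removed (presumably deferring to the shared decoder-only runtime, cf. `decoderonly_runtime_utils.py` +205 above), while the prefill path gains the aggregation shown here: per-chunk logits (and, optionally, hidden states) are collected, concatenated along the sequence axis, trimmed to the real query length or `logits_to_keep`, and scattered back into the positions flagged by the attention mask. A standalone sketch of that concat-and-scatter step, with illustrative shapes:

```python
# Standalone sketch of the chunk aggregation in the new prefill path; tensor
# shapes are illustrative, not the compiled model's real dimensions.
import torch

chunk_logits = [torch.randn(1, 128, 256) for _ in range(3)]  # three 128-token prefill chunks
attention_mask = torch.zeros(300, dtype=torch.long)
attention_mask[10:] = 1                                      # left-padded: 290 real tokens
query_length = int(attention_mask.sum())

# Concatenate along the sequence axis and drop the padded tail of the last chunk.
logits = torch.concat(chunk_logits, dim=-2)[:, :query_length, :]

# Scatter the valid rows back into a full-length tensor at the masked positions.
full_logits = torch.full((1, attention_mask.shape[-1], logits.shape[-1]), 1e-10, dtype=logits.dtype)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
full_logits.index_copy_(dim=-2, index=mask_indices, source=logits)
```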