optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/paligemma/modeling_paligemma.py

@@ -0,0 +1,564 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union
+
+import torch
+from transformers import AutoModelForVision2Seq, PaliGemmaForConditionalGeneration, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import no_init_weights
+from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
+from transformers.models.paligemma.modeling_paligemma import PaligemmaModelOutputWithPast, PaliGemmaMultiModalProjector
+
+from ....configuration_utils import RBLNModelConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
+
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower.model[0])
+
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
+
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = kwargs["out"][index : index + 1]
+        return ([pixel_values_item], {"out": out_buffer})
+
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        return BaseModelOutputWithPooling(
+            last_hidden_state=kwargs["out"],
+        )
+
+
+class RBLNPaliGemmaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
+    """
+    RBLNPaliGemmaForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNPaliGemmaForConditionalGeneration class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNPaliGemmaForConditionalGeneration
+
+        model = RBLNPaliGemmaForConditionalGeneration.from_pretrained(
+            "google/paligemma2-3b-mix-224",
+            export=True,
+            rbln_config={
+                "language_model": {
+                    "prefill_chunk_size": 8192,
+                }
+            },
+            rbln_tensor_parallel_size=4,
+        )
+
+        model.save_pretrained("compiled-paligemma2-3b-mix-224")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+        {"name": "language_model"},
+    ]
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(PaliGemmaForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _update_submodule_rbln_config(
+        cls,
+        submodule_name: str,
+        submodule_cls: Type["RBLNModel"],
+        model: "PreTrainedModel",
+        submodule_config: PretrainedConfig,
+        submodule_rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
+        if submodule_name == "language_model":
+            submodule_config.use_sliding_window = False
+        else:
+            return submodule_rbln_config
+
+        return submodule_rbln_config
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
+    def __post_init__(self, **kwargs):
+        self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+        self.language_model = self.rbln_submodules[1]
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.multi_modal_projector = self._create_multi_modal_projector()
+        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+        return super().__post_init__(**kwargs)
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PaliGemmaForConditionalGeneration",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+        save_dict["multi_modal_projector"] = model.multi_modal_projector.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.language_model.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.language_model.kvcache_num_blocks
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
+
+    def _create_multi_modal_projector(self):
+        with no_init_weights():
+            multi_modal_projector = PaliGemmaMultiModalProjector(self.config)
+        return multi_modal_projector
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        inputs_embeds=None,
+        pixel_values=None,
+        image_sizes=None,
+        attention_mask=None,
+        generate_idx=None,
+        position_ids=None,
+        token_type_ids=None,
+        **kwargs,
+    ):
+        # Prepare HF generation
+        is_prefill_phase = generate_idx is None
+
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            generate_idx=generate_idx,  # Not affect
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            **kwargs,
+        )
+
+        if is_prefill_phase:
+            model_inputs.update(
+                {
+                    "pixel_values": pixel_values,
+                    "token_type_ids": token_type_ids,
+                }
+            )
+
+        model_inputs["attention_mask"] = attention_mask
+
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+        model_kwargs["generate_idx"] = outputs.generate_idx
+        return model_kwargs
+
+    def get_image_features(self, pixel_values: torch.Tensor):
+        vision_output_size = [
+            pixel_values.shape[0],
+            self.config.vision_config.num_image_tokens,
+            self.config.vision_config.hidden_size,
+        ]
+        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+        self.vision_tower(pixel_values, out=vision_output)
+        image_features = self.multi_modal_projector(vision_output)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
+        return image_features
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
+    def _preprocess_prefill(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if input_ids is not None and self.config.image_token_id >= self.config.text_config.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_id
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values)
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: torch.LongTensor = None,
+        position_ids: torch.LongTensor = None,
+        token_type_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: torch.Tensor = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+        # Prefill
+        if cache_position is None:
+            inputs_embeds = self._preprocess_prefill(
+                input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+            )
+            logits = []
+            inputs = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = inputs.shape[0]
+
+            for b_idx in range(batch_size):
+                cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                output = self.language_model.prefill_decoder(
+                    input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                    inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                    attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
+                    cache_position=cache_position,
+                    batch_idx=b_idx,
+                )
+                logits.append(output.logits)
+
+            logits = torch.cat(logits, dim=0)
+        # Decoder
+        else:
+            logits = self.language_model.decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
+            ).logits
+
+        if not return_dict:
+            return logits, generate_idx
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits,
+                generate_idx=generate_idx,
+            )
+
+
+class RBLNPaliGemmaModel(RBLNModel):
+    """
+    RBLNPaliGemmaModel which consists of a vision backbone and a language model without language modeling head,
+    optimized for RBLN NPUs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNPaliGemmaModel class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNPaliGemmaModel
+
+        model = RBLNPaliGemmaModel.from_pretrained(
+            "google/paligemma2-3b-mix-224",
+            export=True,
+            rbln_config={
+                "language_model": {
+                    "prefill_chunk_size": 8192,
+                }
+            },
+            rbln_tensor_parallel_size=4,
+        )
+
+        model.save_pretrained("compiled-paligemma2-3b-mix-224")
+        ```
+    """
+
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+        {"name": "language_model"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+        self.language_model = self.rbln_submodules[1]
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            cfg = self.config if isinstance(self.config, dict) else self.config.to_dict()
+            text_config = cfg.pop("text_config", None)
+            vision_config = cfg.pop("vision_config", None)
+            self.config = PaliGemmaConfig(
+                text_config=text_config,
+                vision_config=vision_config,
+                **cfg,
+            )
+
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self.embed_tokens = self._create_embedding_layer()
+        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        self.multi_modal_projector = self._create_multi_modal_projector()
+        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+        return super().__post_init__(**kwargs)
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    @classmethod
+    def _update_submodule_rbln_config(
+        cls,
+        submodule_name: str,
+        submodule_cls: Type["RBLNModel"],
+        model: "PreTrainedModel",
+        submodule_config: PretrainedConfig,
+        submodule_rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
+        if submodule_name == "language_model":
+            submodule_config.use_sliding_window = False
+        else:
+            return submodule_rbln_config
+
+        return submodule_rbln_config
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PaliGemmaForConditionalGeneration",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+        save_dict["multi_modal_projector"] = model.multi_modal_projector.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
+
+    def _create_multi_modal_projector(self):
+        with no_init_weights():
+            multi_modal_projector = PaliGemmaMultiModalProjector(self.config)
+        return multi_modal_projector
+
+    def get_image_features(self, pixel_values: torch.Tensor):
+        vision_output_size = [
+            pixel_values.shape[0],
+            self.config.vision_config.num_image_tokens,
+            self.config.vision_config.hidden_size,
+        ]
+        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+        self.vision_tower(pixel_values, out=vision_output)
+        image_features = self.multi_modal_projector(vision_output)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
+        return image_features
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ):
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_index
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        return special_image_mask
+
+    def _preprocess_prefill(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ):
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if input_ids is not None and self.config.image_token_index >= self.config.text_config.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        image_features = None
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values)
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            special_image_mask = self.get_placeholder_mask(
+                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, image_features
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, PaligemmaModelOutputWithPast]:
+        """
+        Forward pass for the RBLN-optimized PaliGemmaModel model.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary.
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length)) — Mask to avoid performing attention on padding token indices.
+            position_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of positions of each input sequence tokens in the position embeddings.
+            token_type_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Segment token indices to indicate first and second portions of the inputs.
+            output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            PaligemmaModelOutputWithPast or tuple(torch.FloatTensor)
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        inputs_embeds, image_features = self._preprocess_prefill(
+            input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+        )
+
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_hidden_states=output_hidden_states,
+        )
+
+        return PaligemmaModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            image_hidden_states=image_features if pixel_values is not None else None,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+        )
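The docstring examples in the new file cover only compilation and saving. Below is a minimal inference sketch for the compiled RBLNPaliGemmaForConditionalGeneration, assuming the standard `transformers` AutoProcessor workflow for PaliGemma; the processor call, prompt text, and image path are illustrative and not part of this diff.

```python
# Hedged sketch, not part of the package diff: loads the directory produced by
# model.save_pretrained("compiled-paligemma2-3b-mix-224") in the docstring example above.
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNPaliGemmaForConditionalGeneration

model = RBLNPaliGemmaForConditionalGeneration.from_pretrained(
    "compiled-paligemma2-3b-mix-224",
    export=False,  # load the already-compiled RBLN artifacts instead of recompiling
)
processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

image = Image.open("example.jpg")  # hypothetical local image
inputs = processor(text="answer en What is in the image?", images=image, return_tensors="pt")

# generate() is provided by RBLNDecoderOnlyGenerationMixin; forward() runs the vision
# tower and prefill, then the decoder runtime produces new tokens step by step.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```

Note that in the prefill branch of `forward` the wrapper iterates over the batch and calls `prefill_decoder` once per sample before switching to the decoder runtime, so batched prompts are prefilled sequentially.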
optimum/rbln/transformers/models/pegasus/modeling_pegasus.py

@@ -54,7 +54,7 @@ class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
         return PegasusWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )

optimum/rbln/transformers/models/pegasus/pegasus_architecture.py

@@ -60,10 +60,10 @@ class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
 class PegasusDecoder(Seq2SeqDecoder):
     has_pos_emb = True
 
-    def __post_init__(self):
-        self.embed_positions =
-        self.embed_scale = getattr(
-        self.final_layer_norm = getattr(
+    def __post_init__(self, model: nn.Module):
+        self.embed_positions = model.embed_positions
+        self.embed_scale = getattr(model, "embed_scale", None)
+        self.final_layer_norm = getattr(model, "layer_norm", None)
 
     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
         if attention_mask is not None:

@@ -110,11 +110,11 @@ class PegasusLayerFF(nn.Module):
 
 
 class PegasusDecoderLayer(Seq2SeqDecoderLayer):
-    def __post_init__(self):
-        self.self_attn_layer_norm =
-        self.encoder_attn =
-        self.encoder_attn_layer_norm =
-        self.ff_layer = PegasusLayerFF(
+    def __post_init__(self, decoder_layer: nn.Module):
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn = decoder_layer.encoder_attn
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.ff_layer = PegasusLayerFF(decoder_layer)
 
     def pre_self_attn_layer_norm(self, hidden_states):
         return self.self_attn_layer_norm(hidden_states)

@@ -130,13 +130,13 @@ class PegasusDecoderLayer(Seq2SeqDecoderLayer):
 
 
 class PegasusSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self, use_attention_mask: bool = True):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.out_proj =
-        self.num_heads =
-        self.head_dim =
+    def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
         self.scaling = self.head_dim**-0.5
         if use_attention_mask:
             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode

@@ -151,11 +151,11 @@ class PegasusSelfAttention(Seq2SeqSelfAttention):
 
 
 class PegasusCrossAttention(Seq2SeqCrossAttention):
-    def __post_init__(self):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.out_proj =
-        self.num_heads =
-        self.head_dim =
-        self.embed_dim =
+    def __post_init__(self, attn: nn.Module):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
+        self.embed_dim = attn.embed_dim
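The pegasus_architecture.py hunks change the `__post_init__` hooks so that each RBLN-side wrapper receives the original Hugging Face module explicitly (`model`, `decoder_layer`, `attn`) and derives `head_dim` as `embed_dim // num_heads` rather than reading it off the module. A simplified sketch of that wiring, assuming the standard `transformers` Pegasus module layout; `WrappedSelfAttention` below only illustrates the pattern and is not the actual `Seq2SeqSelfAttention` base class:

```python
# Illustration only: mirrors the attribute wiring shown in the PegasusSelfAttention
# hunk, using the real module layout of transformers' PegasusForConditionalGeneration.
import torch.nn as nn
from transformers import PegasusForConditionalGeneration


class WrappedSelfAttention(nn.Module):
    def __init__(self, attn: nn.Module, use_attention_mask: bool = True):
        super().__init__()
        # Borrow the projections and shape info from the original attention module.
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        self.head_dim = attn.embed_dim // attn.num_heads  # derived, as in the diff
        self.scaling = self.head_dim**-0.5
        self.use_attention_mask = use_attention_mask


hf_model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
attn = hf_model.model.decoder.layers[0].self_attn  # the module passed in as `attn`
wrapped = WrappedSelfAttention(attn)
print(wrapped.num_heads, wrapped.head_dim)
```

Passing the source module explicitly keeps the wrapper free of assumptions about where the caller stores it, which is the design choice these hunks converge on across the decoder, decoder layer, and both attention wrappers.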