optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -123,7 +123,10 @@ class MidmAttention(DecoderOnlyAttention):
|
|
|
123
123
|
self.split_size = self._original_mod.split_size
|
|
124
124
|
self.num_key_value_heads = self._original_mod.num_heads
|
|
125
125
|
|
|
126
|
-
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
126
|
+
def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
127
|
+
if lora_int_id is not None:
|
|
128
|
+
raise NotImplementedError("LoRA is not supported for MidmAttention")
|
|
129
|
+
|
|
127
130
|
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
|
128
131
|
return query_states, key_states, value_states
|
|
129
132
|
|
|
@@ -13,11 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import inspect
|
|
16
|
-
from
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Callable, Dict, Optional, Union
|
|
17
18
|
|
|
18
19
|
from transformers import AutoModelForCausalLM
|
|
19
20
|
from transformers.generation.utils import GenerationMixin
|
|
20
21
|
|
|
22
|
+
from ....configuration_utils import RBLNModelConfig
|
|
21
23
|
from ....utils import logging
|
|
22
24
|
from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
23
25
|
from .midm_architecture import MidmLMHeadModelWrapper
|
|
@@ -91,9 +93,45 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
|
|
91
93
|
_supports_cache_class = True
|
|
92
94
|
|
|
93
95
|
@classmethod
|
|
94
|
-
def from_pretrained(
|
|
95
|
-
|
|
96
|
-
|
|
96
|
+
def from_pretrained(
|
|
97
|
+
cls,
|
|
98
|
+
model_id: Union[str, Path],
|
|
99
|
+
*,
|
|
100
|
+
export: Optional[bool] = None,
|
|
101
|
+
rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
|
|
102
|
+
trust_remote_code: Optional[bool] = None,
|
|
103
|
+
**kwargs: Any,
|
|
104
|
+
):
|
|
105
|
+
"""
|
|
106
|
+
The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
|
|
107
|
+
User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
|
|
111
|
+
It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
|
|
112
|
+
export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
|
|
113
|
+
If None, it will be determined based on the existence of the compiled model files in the model_id.
|
|
114
|
+
rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
|
|
115
|
+
This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNMidmLMHeadModelConfig` for Mi:dm models).
|
|
116
|
+
For detailed configuration options, see the specific model's configuration class documentation.
|
|
117
|
+
trust_remote_code (bool): Whether or not to trust the remote code when loading a model from the Hub.
|
|
118
|
+
kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
(RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
|
|
122
|
+
"""
|
|
123
|
+
|
|
124
|
+
if trust_remote_code is not None:
|
|
125
|
+
kwargs["trust_remote_code"] = trust_remote_code
|
|
126
|
+
elif "trust_remote_code" not in kwargs:
|
|
127
|
+
kwargs["trust_remote_code"] = True
|
|
128
|
+
|
|
129
|
+
return super().from_pretrained(
|
|
130
|
+
model_id=model_id,
|
|
131
|
+
export=export,
|
|
132
|
+
rbln_config=rbln_config,
|
|
133
|
+
**kwargs,
|
|
134
|
+
)
|
|
97
135
|
|
|
98
136
|
def __getattr__(self, __name: str) -> Any:
|
|
99
137
|
def redirect(func):
|
|
@@ -12,5 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_mistral import RBLNMistralForCausalLMConfig
|
|
16
|
-
from .modeling_mistral import RBLNMistralForCausalLM
|
|
15
|
+
from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
|
|
16
|
+
from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
@@ -40,3 +40,11 @@ class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
|
40
40
|
)
|
|
41
41
|
```
|
|
42
42
|
"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
|
|
46
|
+
"""
|
|
47
|
+
Configuration class for RBLN Mistral models.
|
|
48
|
+
|
|
49
|
+
This class is an alias of RBLNDecoderOnlyModelConfig.
|
|
50
|
+
"""
|
|
@@ -15,8 +15,12 @@
|
|
|
15
15
|
from transformers import PretrainedConfig
|
|
16
16
|
|
|
17
17
|
from ....utils import logging
|
|
18
|
-
from ...models.decoderonly import
|
|
19
|
-
|
|
18
|
+
from ...models.decoderonly import (
|
|
19
|
+
RBLNDecoderOnlyModel,
|
|
20
|
+
RBLNDecoderOnlyModelForCausalLM,
|
|
21
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
22
|
+
)
|
|
23
|
+
from .mistral_architecture import MistralWrapper
|
|
20
24
|
|
|
21
25
|
|
|
22
26
|
logger = logging.get_logger(__name__)
|
|
@@ -79,7 +83,26 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
79
83
|
```
|
|
80
84
|
"""
|
|
81
85
|
|
|
82
|
-
_decoder_wrapper_cls =
|
|
86
|
+
_decoder_wrapper_cls = MistralWrapper
|
|
87
|
+
|
|
88
|
+
@classmethod
|
|
89
|
+
def _update_sliding_window_config(
|
|
90
|
+
cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
91
|
+
):
|
|
92
|
+
rbln_config.cache_impl = "sliding_window"
|
|
93
|
+
rbln_config.sliding_window = model_config.sliding_window
|
|
94
|
+
rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
|
|
95
|
+
|
|
96
|
+
return rbln_config
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class RBLNMistralModel(RBLNDecoderOnlyModel):
|
|
100
|
+
"""
|
|
101
|
+
The Mistral Model transformer without a language modeling head.
|
|
102
|
+
This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
_decoder_wrapper_cls = MistralWrapper
|
|
83
106
|
|
|
84
107
|
@classmethod
|
|
85
108
|
def _update_sliding_window_config(
|
|
@@ -12,5 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_opt import RBLNOPTForCausalLMConfig
|
|
16
|
-
from .modeling_opt import RBLNOPTForCausalLM
|
|
15
|
+
from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
|
|
16
|
+
from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
@@ -20,3 +20,10 @@ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
|
20
20
|
Configuration class for OPT causal language model.
|
|
21
21
|
Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
|
|
22
22
|
"""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
|
|
26
|
+
"""
|
|
27
|
+
Configuration class for OPT model.
|
|
28
|
+
Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
|
|
29
|
+
"""
|
|
@@ -16,7 +16,7 @@ import torch.nn as nn
|
|
|
16
16
|
from transformers import PreTrainedModel
|
|
17
17
|
|
|
18
18
|
from ....utils import logging
|
|
19
|
-
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
19
|
+
from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
|
|
20
20
|
from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
21
21
|
from .opt_architecture import OPTWrapper
|
|
22
22
|
|
|
@@ -69,22 +69,34 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
69
69
|
return layer
|
|
70
70
|
|
|
71
71
|
@classmethod
|
|
72
|
-
def
|
|
73
|
-
wrapper_cfg = {
|
|
74
|
-
"max_seq_len": rbln_config.max_seq_len,
|
|
75
|
-
"attn_impl": rbln_config.attn_impl,
|
|
76
|
-
"kvcache_partition_len": rbln_config.kvcache_partition_len,
|
|
77
|
-
"kvcache_block_size": rbln_config.kvcache_block_size,
|
|
78
|
-
"use_rotary_emb": cls._use_rotary_emb,
|
|
79
|
-
"use_attention_mask": rbln_config.use_attention_mask,
|
|
80
|
-
"use_position_ids": rbln_config.use_position_ids,
|
|
81
|
-
"use_inputs_embeds": rbln_config.use_inputs_embeds,
|
|
82
|
-
"cache_impl": rbln_config.cache_impl,
|
|
83
|
-
"sliding_window": rbln_config.sliding_window,
|
|
84
|
-
"sliding_window_layers": rbln_config.sliding_window_layers,
|
|
85
|
-
}
|
|
86
|
-
|
|
72
|
+
def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
87
73
|
for i in range(len(model.model.decoder.layers)):
|
|
88
74
|
model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
|
|
89
75
|
|
|
90
|
-
return cls._decoder_wrapper_cls(model,
|
|
76
|
+
return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class RBLNOPTModel(RBLNDecoderOnlyModel):
|
|
80
|
+
"""
|
|
81
|
+
The OPT Model transformer without a language modeling head.
|
|
82
|
+
This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
_decoder_wrapper_cls = OPTWrapper
|
|
86
|
+
_use_rotary_emb = False
|
|
87
|
+
|
|
88
|
+
def modify_opt_decoder_layer(layer):
|
|
89
|
+
mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
|
|
90
|
+
layer.mlp = mlp
|
|
91
|
+
del layer.fc1
|
|
92
|
+
del layer.fc2
|
|
93
|
+
del layer.activation_fn
|
|
94
|
+
|
|
95
|
+
return layer
|
|
96
|
+
|
|
97
|
+
@classmethod
|
|
98
|
+
def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
99
|
+
for i in range(len(model.decoder.layers)):
|
|
100
|
+
model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
|
|
101
|
+
|
|
102
|
+
return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
|
|
@@ -40,11 +40,11 @@ class OPTWrapper(DecoderOnlyWrapper):
|
|
|
40
40
|
def get_rbln_model_class(self):
|
|
41
41
|
return OPTModel
|
|
42
42
|
|
|
43
|
-
def get_model_layer(self,
|
|
44
|
-
return
|
|
43
|
+
def get_model_layer(self, model: "OPTForCausalLM"):
|
|
44
|
+
return model.model.decoder if self.is_causal_lm else model.decoder
|
|
45
45
|
|
|
46
|
-
def get_decoder_layers(self,
|
|
47
|
-
return
|
|
46
|
+
def get_decoder_layers(self, model: "OPTForCausalLM"):
|
|
47
|
+
return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
class OPTAttention(DecoderOnlyAttention):
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ....ops import paged_attn_decode, paged_causal_attn_decode
|
|
16
|
+
from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
|
|
17
|
+
from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
|
|
20
|
+
"""
|
|
21
|
+
Configuration class for RBLNPegasusModel.
|
|
22
|
+
|
|
23
|
+
This configuration class stores the configuration parameters specific to
|
|
24
|
+
RBLN-optimized PEGASUS models for feature extraction tasks.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
|
|
31
|
+
"""
|
|
32
|
+
Configuration class for RBLNPegasusForConditionalGeneration.
|
|
33
|
+
|
|
34
|
+
This configuration class stores the configuration parameters specific to
|
|
35
|
+
RBLN-optimized PEGASUS models for conditional text generation tasks.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
support_paged_attention = True
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
17
|
+
|
|
18
|
+
from transformers import PegasusForConditionalGeneration, PreTrainedModel
|
|
19
|
+
|
|
20
|
+
from ....utils.logging import get_logger
|
|
21
|
+
from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
|
|
22
|
+
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
|
23
|
+
from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
|
|
24
|
+
from .pegasus_architecture import PegasusWrapper
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
logger = get_logger()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from transformers import PreTrainedModel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
|
|
35
|
+
"""
|
|
36
|
+
RBLN optimized PEGASUS model for feature extraction tasks.
|
|
37
|
+
|
|
38
|
+
This class provides hardware-accelerated inference for PEGASUS encoder models
|
|
39
|
+
on RBLN devices, optimized for feature extraction use cases.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """
    RBLN optimized PEGASUS model for conditional text generation tasks.

    This class provides hardware-accelerated inference for PEGASUS models
    on RBLN devices, supporting sequence-to-sequence generation tasks
    such as summarization, translation, and text generation.
    """

    support_causal_attn = True

    @classmethod
    def _wrap_model_if_needed(
        cls, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig
    ):
        """
        Wrap a HuggingFace PEGASUS model into the encoder/decoder wrapper used for compilation.

        Args:
            model: The original HuggingFace PEGASUS model.
            rbln_config: RBLN configuration; its `enc_max_seq_len` and
                `use_attention_mask` values are forwarded to the wrapper.

        Returns:
            A `PegasusWrapper` holding the encoder and decoder wrappers.
        """
        return PegasusWrapper(
            model,
            enc_max_seq_len=rbln_config.enc_max_seq_len,
            use_attention_mask=rbln_config.use_attention_mask,
        )

    def __getattr__(self, __name: str) -> Any:
        """
        Fall back to attributes of HuggingFace `PegasusForConditionalGeneration`.

        This is only invoked when normal attribute lookup on this instance fails.
        A callable whose signature has a `self` parameter is returned wrapped so
        that it is called with this RBLN instance as `self`. Every other
        attribute is returned unchanged.

        Raises:
            AttributeError: If `PegasusForConditionalGeneration` has no such attribute.
        """
        val = getattr(PegasusForConditionalGeneration, __name)

        if not callable(val):
            return val

        try:
            params = inspect.signature(val).parameters
        except (TypeError, ValueError):
            # Some builtin / C-implemented callables have no retrievable signature.
            # Letting that error escape __getattr__ would make `hasattr()` raise
            # instead of answering, so such callables are returned as-is.
            return val

        if "self" in params:
            # Call the HuggingFace method with this RBLN instance bound as `self`.
            return lambda *pargs, **kwargs: val(self, *pargs, **kwargs)

        return val
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
20
|
+
from transformers.utils import logging
|
|
21
|
+
|
|
22
|
+
from ..seq2seq.seq2seq_architecture import (
|
|
23
|
+
Seq2SeqCrossAttention,
|
|
24
|
+
Seq2SeqDecoder,
|
|
25
|
+
Seq2SeqDecoderLayer,
|
|
26
|
+
Seq2SeqDecoderWrapper,
|
|
27
|
+
Seq2SeqEncoderWrapper,
|
|
28
|
+
Seq2SeqForConditionalGeneration,
|
|
29
|
+
Seq2SeqSelfAttention,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
logger = logging.get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PegasusWrapper:
    """
    Holds the encoder and decoder wrappers built from one PEGASUS model.

    Args:
        model: The original HuggingFace PEGASUS model.
        enc_max_seq_len: Maximum encoder sequence length, passed to `Seq2SeqEncoderWrapper`.
        use_attention_mask: Passed to `PegasusDecoderWrapper`.
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
        # The encoder wrapper is constructed first, then the decoder wrapper.
        encoder_wrapper = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        decoder_wrapper = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
    """Decoder wrapper that rebuilds a PEGASUS decoder from RBLN-specific layer wrappers."""

    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """
        Build the RBLN conditional-generation module from a HuggingFace PEGASUS model.

        For every decoder layer, the self-attention is wrapped with
        `PegasusSelfAttention` and the encoder attention with
        `PegasusCrossAttention`. Both are then combined into a
        `PegasusDecoderLayer`. The wrapped layers form a `PegasusDecoder`, which
        is finally combined with `model` into a `PegasusForConditionalGeneration`.

        Args:
            model: The original HuggingFace PEGASUS model.

        Returns:
            The assembled `PegasusForConditionalGeneration` module.
        """
        hf_decoder = model.get_decoder()

        def wrap_layer(layer):
            # Attention wrappers are built first (self, then cross), then composed.
            self_attention = PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
            cross_attention = PegasusCrossAttention(layer.encoder_attn)
            return PegasusDecoderLayer(layer, self_attention, cross_attention)

        rbln_layers = [wrap_layer(layer) for layer in hf_decoder.layers]
        rbln_decoder = PegasusDecoder(hf_decoder, rbln_layers)
        return PegasusForConditionalGeneration(model, rbln_decoder)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
    """PEGASUS specialization of `Seq2SeqForConditionalGeneration`; all behavior is inherited unchanged."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class PegasusDecoder(Seq2SeqDecoder):
    """
    RBLN decoder wrapper for PEGASUS.

    Adds position embeddings looked up from `embed_positions.weight`, optional
    token-embedding scaling, and 4D attention-mask preparation on top of
    `Seq2SeqDecoder`.
    """

    has_pos_emb = True

    def __post_init__(self):
        original = self._original_mod
        self.embed_positions = original.embed_positions
        # Optional on the original decoder: default to None when absent.
        self.embed_scale = getattr(original, "embed_scale", None)
        self.final_layer_norm = getattr(original, "layer_norm", None)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
        """
        Expand the decoder and encoder attention masks to 4D.

        The decoder mask, if given, gets two singleton axes inserted after the
        first axis. Presumably this maps [batch, seq] to [batch, 1, 1, seq];
        TODO confirm the input rank. The encoder mask is converted with
        `_prepare_4d_attention_mask` as float32 with `tgt_len=1`.

        Returns:
            Tuple of (decoder mask or None, encoder mask).
        """
        decoder_mask = None if attention_mask is None else attention_mask[:, None, None, :]
        encoder_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
        return decoder_mask, encoder_mask

    def apply_position_embedding(self, inputs_embeds, cache_position):
        """
        Add position embeddings to `inputs_embeds` one batch row at a time.

        Row `b` of the result is `embed_positions.weight[cache_position[b]] + inputs_embeds[b]`.
        The rows are stacked back along dim 0.
        """
        table = self.embed_positions.weight
        # An explicit per-row gather rather than one batched index. This is
        # presumably kept for RBLN compilation; confirm before vectorizing.
        per_row = [table[cache_position[b]] + inputs_embeds[b] for b in range(inputs_embeds.shape[0])]
        return torch.stack(per_row, dim=0)

    def get_embedding(self):
        """Return the token-embedding callable, scaled by `embed_scale` when one is set."""
        if self.embed_scale is None:
            return self.embed_tokens

        def scaled_embedding(x):
            return self.embed_tokens(x) * self.embed_scale

        return scaled_embedding
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class PegasusLayerFF(nn.Module):
    """
    Pre-norm feed-forward sub-block of a PEGASUS decoder layer.

    Reuses the `fc1`, `fc2`, `activation_fn` and `final_layer_norm` of the given
    decoder layer and computes ``x + fc2(activation_fn(fc1(final_layer_norm(x))))``.
    """

    def __init__(self, decoder_layer):
        super().__init__()
        # Assignment order kept as-is (it fixes submodule registration order).
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2
        self.activation_fn = decoder_layer.activation_fn
        self.layer_norm = decoder_layer.final_layer_norm

    def forward(self, hidden_states):
        normed = self.layer_norm(hidden_states)
        projected = self.fc2(self.activation_fn(self.fc1(normed)))
        # Residual connection: the un-normalized input plus the feed-forward output.
        return hidden_states + projected
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class PegasusDecoderLayer(Seq2SeqDecoderLayer):
    """
    PEGASUS decoder layer: layer norm is applied before self- and cross-attention.

    The post-attention hooks are identity functions, and the feed-forward part
    is delegated to `PegasusLayerFF`.
    """

    def __post_init__(self):
        original = self._original_mod
        # Order preserved: it determines submodule registration order.
        self.self_attn_layer_norm = original.self_attn_layer_norm
        self.encoder_attn = original.encoder_attn
        self.encoder_attn_layer_norm = original.encoder_attn_layer_norm
        self.ff_layer = PegasusLayerFF(original)

    def pre_self_attn_layer_norm(self, hidden_states):
        """Normalize the input ahead of self-attention."""
        normalized = self.self_attn_layer_norm(hidden_states)
        return normalized

    def post_self_attn_layer_norm(self, hidden_states):
        """Identity: no normalization after self-attention."""
        return hidden_states

    def pre_cross_attn_layer_norm(self, hidden_states):
        """Normalize the input ahead of cross-attention."""
        normalized = self.encoder_attn_layer_norm(hidden_states)
        return normalized

    def post_cross_attn_layer_norm(self, hidden_states):
        """Identity: no normalization after cross-attention."""
        return hidden_states
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class PegasusSelfAttention(Seq2SeqSelfAttention):
    """PEGASUS decoder self-attention that decodes through RBLN paged-attention custom ops."""

    def __post_init__(self, use_attention_mask: bool = True):
        """
        Copy projections and head geometry from the original module and pick the decode op.

        Args:
            use_attention_mask: True selects `rbln_custom_ops.paged_attn_decode`;
                False selects `rbln_custom_ops.paged_causal_attn_decode`.
        """
        original = self._original_mod
        self.q_proj = original.q_proj
        self.k_proj = original.k_proj
        self.v_proj = original.v_proj
        self.out_proj = original.out_proj
        self.num_heads = original.num_heads
        self.head_dim = original.embed_dim // original.num_heads
        self.scaling = self.head_dim**-0.5
        rbln_ops = torch.ops.rbln_custom_ops
        self.attn_decode = rbln_ops.paged_attn_decode if use_attention_mask else rbln_ops.paged_causal_attn_decode

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (query, key, value) projections; the query is pre-multiplied by `self.scaling`."""
        return (
            self.q_proj(hidden_states) * self.scaling,
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class PegasusCrossAttention(Seq2SeqCrossAttention):
    """PEGASUS decoder cross-attention; reuses the original module's projections and head geometry."""

    def __post_init__(self):
        original = self._original_mod
        self.q_proj = original.q_proj
        self.k_proj = original.k_proj
        self.v_proj = original.v_proj
        self.out_proj = original.out_proj
        self.num_heads = original.num_heads
        self.head_dim = original.embed_dim // original.num_heads
        self.embed_dim = original.embed_dim
|
|
@@ -12,5 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_phi import RBLNPhiForCausalLMConfig
|
|
16
|
-
from .modeling_phi import RBLNPhiForCausalLM
|
|
15
|
+
from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
|
|
16
|
+
from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
@@ -40,3 +40,11 @@ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
|
40
40
|
)
|
|
41
41
|
```
|
|
42
42
|
"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
    """
    Configuration class for RBLN Phi models.

    Subclasses RBLNDecoderOnlyModelConfig without adding or overriding anything,
    so every option is inherited as-is.
    """
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from ....utils import logging
|
|
16
|
-
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
16
|
+
from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
|
|
17
17
|
from .phi_architecture import PhiWrapper
|
|
18
18
|
|
|
19
19
|
|
|
@@ -81,3 +81,12 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
81
81
|
"""
|
|
82
82
|
|
|
83
83
|
_decoder_wrapper_cls = PhiWrapper
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RBLNPhiModel(RBLNDecoderOnlyModel):
    """
    Phi transformer without a language modeling head, for RBLN devices.

    Generic behavior is inherited from [`RBLNDecoderOnlyModel`]; see the superclass
    documentation for the methods the library implements for all its models.
    """

    # Same decoder wrapper class that RBLNPhiForCausalLM uses.
    _decoder_wrapper_cls = PhiWrapper
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Optional, Tuple
|
|
15
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from transformers import PhiForCausalLM
|
|
@@ -27,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
|
-
from transformers import PhiForCausalLM
|
|
30
|
+
from transformers import PhiForCausalLM, PhiModel
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class PhiWrapper(DecoderOnlyWrapper):
|
|
@@ -40,11 +40,11 @@ class PhiWrapper(DecoderOnlyWrapper):
|
|
|
40
40
|
def get_rbln_model_class(self):
|
|
41
41
|
return PhiModel
|
|
42
42
|
|
|
43
|
-
def get_model_layer(self,
|
|
44
|
-
return
|
|
43
|
+
def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
    """Return the base model: `model.model` when wrapping a causal LM, otherwise `model` itself."""
    if self.is_causal_lm:
        return model.model
    return model
|
|
45
45
|
|
|
46
|
-
def get_decoder_layers(self,
|
|
47
|
-
return
|
|
46
|
+
def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
    """Return the `layers` of the base model (`model.model` for a causal LM, else `model`)."""
    base_model = model.model if self.is_causal_lm else model
    return base_model.layers
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
class PhiAttention(DecoderOnlyAttention):
|
|
@@ -56,7 +56,10 @@ class PhiAttention(DecoderOnlyAttention):
|
|
|
56
56
|
self.qk_layernorm = self._original_mod.qk_layernorm
|
|
57
57
|
self.rotary_ndims = self._original_mod.rotary_ndims
|
|
58
58
|
|
|
59
|
-
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
59
|
+
def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
60
|
+
if lora_int_id is not None:
|
|
61
|
+
raise NotImplementedError("LoRA is not supported for PhiAttention")
|
|
62
|
+
|
|
60
63
|
query_states = self.q_proj(hidden_states)
|
|
61
64
|
key_states = self.k_proj(hidden_states)
|
|
62
65
|
value_states = self.v_proj(hidden_states)
|
|
@@ -84,6 +87,7 @@ class PhiLayer(DecoderOnlyLayer):
|
|
|
84
87
|
cos: Optional[torch.Tensor] = None,
|
|
85
88
|
sin: Optional[torch.Tensor] = None,
|
|
86
89
|
block_tables: Optional[torch.Tensor] = None,
|
|
90
|
+
lora_int_id: Optional[torch.Tensor] = None,
|
|
87
91
|
):
|
|
88
92
|
residual = hidden_states
|
|
89
93
|
|