optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from transformers import PretrainedConfig
+
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
 from .qwen2_architecture import QWEN2Wrapper
 
 
@@ -22,13 +24,74 @@ logger = logging.get_logger(__name__)
 
 class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
-    The
+    The Qwen2 Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
-    A class to convert and run pre-trained transformers based
-    It implements the methods to convert a pre-trained transformers
+    A class to convert and run pre-trained transformers based Qwen2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen2ForCausalLM model into a RBLN transformer model by:
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNQwen2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNQwen2ForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2ForCausalLM
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNQwen2ForCausalLM.from_pretrained(
+            "Qwen/Qwen2-7B-Instruct",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=4,
+        )
+
+
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "max_seq_len": 4096,
+            "tensor_parallel_size": 4,
+        }
+        model = RBLNQwen2ForCausalLM.from_pretrained(
+            "Qwen/Qwen2-7B-Instruct",
+            export=True,
+            rbln_config=rbln_config
+        )
+
+
+        # Using a RBLNQwen2ForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNQwen2ForCausalLMConfig
+
+        config = RBLNQwen2ForCausalLMConfig(
+            batch_size=1,
+            max_seq_len=4096,
+            tensor_parallel_size=4
+        )
+        model = RBLNQwen2ForCausalLM.from_pretrained(
+            "Qwen/Qwen2-7B-Instruct",
+            export=True,
+            rbln_config=config
+        )
+        ```
     """
 
     _decoder_wrapper_cls = QWEN2Wrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # https://github.com/huggingface/transformers/issues/35896
+        # There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
+        # we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
+
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+        return rbln_config
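For readers skimming the hunk above: the new `_update_sliding_window_config` hook simply forces every decoder layer into sliding-window caching until the referenced transformers issue is resolved. A minimal standalone sketch of the same assignments, using `SimpleNamespace` stand-ins rather than the real `PretrainedConfig`/`RBLNDecoderOnlyModelForCausalLMConfig` objects:

```python
from types import SimpleNamespace

# Illustrative stand-ins only; the real classes come from transformers and optimum.rbln.
model_config = SimpleNamespace(sliding_window=4096, num_hidden_layers=28)
rbln_config = SimpleNamespace(cache_impl=None, sliding_window=None, sliding_window_layers=[])

# Same effect as the classmethod added in the diff: every layer uses the sliding window.
rbln_config.cache_impl = "sliding_window"
rbln_config.sliding_window = model_config.sliding_window
rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))

print(rbln_config.sliding_window_layers[:4])  # [0, 1, 2, 3]
```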
@@ -12,20 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLNQwen2_5_VLForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Qwen2.5-VL models for multimodal conditional generation tasks
+    that combine vision and language processing capabilities.
+    """
+
     submodules = ["visual"]
 
     def __init__(
         self,
         visual: Optional[RBLNModelConfig] = None,
         use_inputs_embeds: bool = True,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
         if not self.use_inputs_embeds:
@@ -37,7 +45,15 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausal
 
 
 class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
-
+    """
+    Configuration class for RBLNQwen2_5_VisionTransformerPretrainedModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Qwen2.5-VL vision transformer models with window-based attention
+    mechanisms for processing images and videos.
+    """
+
+    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Dict[str, Any]):
         """
         Args:
             max_seq_lens (Optional[Union[int, List[int]]]): Maximum sequence lengths for Vision
@@ -54,6 +70,18 @@ class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
 
         Raises:
             ValueError: If batch_size is not a positive integer.
+
+        Max Seq Lens:
+            Since `Qwen2_5_VLForConditionalGeneration` performs inference on a per-image or per-frame basis,
+            `max_seq_lens` should be set based on the maximum expected resolution of the input images or video frames,
+            according to the following guidelines:
+
+            1. **Minimum Value**: `max_seq_lens` must be greater than or equal to the number of patches generated from the input image.
+               For example, a 224x224 image with a patch size of 14 results in (224 / 14) * (224 / 14) = 256 patches.
+               Therefore, `max_seq_lens` must be at least 256.
+            2. **Alignment Requirement**: `max_seq_lens` must be a multiple of `(window_size / patch_size)^2` due to the requirements
+               of the window-based attention mechanism. For instance, if `window_size` is 112 and `patch_size` is 14, then
+               `(112 / 14)^2 = 64`, meaning valid values for `max_seq_lens` include 64, 128, 192, 256, etc.
         """
         super().__init__(**kwargs)
 
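The `max_seq_lens` guidelines quoted in the docstring above boil down to a small calculation. A minimal sketch, not part of the package (the helper name and default `patch_size`/`window_size` values are taken from the docstring's examples), that picks the smallest valid value for a given frame size:

```python
import math

def min_valid_max_seq_lens(height: int, width: int, patch_size: int = 14, window_size: int = 112) -> int:
    """Illustrative helper: smallest max_seq_lens satisfying both guidelines above."""
    # Guideline 1: at least as many positions as image patches.
    num_patches = (height // patch_size) * (width // patch_size)
    # Guideline 2: a multiple of (window_size / patch_size)^2, so round up to that alignment.
    alignment = (window_size // patch_size) ** 2
    return max(alignment, math.ceil(num_patches / alignment) * alignment)

print(min_valid_max_seq_lens(224, 224))   # 256 patches -> 256 (already a multiple of 64)
print(min_valid_max_seq_lens(1288, 966))  # 6348 patches -> rounded up to 6400
```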
@@ -37,6 +37,7 @@ from ....utils.logging import get_logger
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+    RBLNQwen2_5_VLForConditionalGenerationConfig,
 )
 from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_VL_LanguageModelWrapper
 
@@ -53,6 +54,14 @@ if TYPE_CHECKING:
 
 
 class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
+    """
+    RBLN optimized Qwen2.5-VL vision transformer model.
+
+    This class provides hardware-accelerated inference for Qwen2.5-VL vision transformers
+    on RBLN devices, supporting image and video encoding for multimodal vision-language tasks
+    with window-based attention mechanisms.
+    """
+
     auto_model_class = None
 
     def __post_init__(self, **kwargs):
@@ -338,6 +347,40 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
 
 
 class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
+
+        model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "kvcache_partition_len": 16_384,
+                "max_seq_len": 114_688,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [
         {"name": "visual"},
@@ -369,33 +412,19 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         cls,
         batch_size: int,
         query_length: int,
-
-
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
+        rbln_config: RBLNQwen2_5_VLForConditionalGenerationConfig,
+        model_config: PretrainedConfig,
     ):
-        input_info = super().get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            use_attention_mask,
-            use_position_ids,
-            max_seq_len,
-            kvcache_block_size,
-            kvcache_num_blocks,
-            num_key_value_heads,
-            num_hidden_layers,
-            hidden_size,
-            head_dim,
-        )
+        input_info = super().get_input_info(batch_size, query_length, rbln_config, model_config)
         pos_idx = 3
-        input_info.insert(
+        input_info.insert(
+            pos_idx,
+            (
+                "position_emb",
+                [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
+                "float32",
+            ),
+        )
 
         return input_info
 
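The reworked `get_input_info` above now reads everything from `rbln_config`/`model_config` and inserts a `position_emb` entry. A small illustration of the shape that entry describes; the concrete numbers below are hypothetical (not read from a real checkpoint), and the leading 2 presumably holds the cos/sin halves of the rotary position embedding:

```python
# Hypothetical values for illustration only.
batch_size, query_length = 1, 256
hidden_size, num_attention_heads = 3584, 28   # head_dim = 3584 // 28 = 128
head_dim = hidden_size // num_attention_heads

position_emb_entry = (
    "position_emb",
    [2, batch_size, 1, query_length, head_dim],
    "float32",
)
print(position_emb_entry[1])  # [2, 1, 1, 256, 128]
```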
@@ -79,7 +79,7 @@ class Qwen2_5_VLVisionFullAttention(nn.Module):
         super().__init__()
         self._origin_model = model
         self.num_heads = model.num_heads
-        self.head_dim = model.
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
 
@@ -114,7 +114,7 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
         super().__init__()
         self._origin_model = model
         self.num_heads = model.num_heads
-        self.head_dim = model.
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
         self.window_seq_len = window_seq_len
@@ -162,7 +162,8 @@ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
         input_ids = None if self.use_inputs_embeds else args.pop(0)
         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
         cache_position = args.pop(0)
-
+        global_block_tables = args.pop(0)
+        local_block_tables = None
         position_embeds = args.pop(0)
         query_position = args.pop(0) if self.phase == "prefill" else None
         position_ids = None
@@ -188,7 +189,8 @@ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
             input_ids,
             inputs_embeds,
             cache_position,
-
+            global_block_tables,
+            local_block_tables,
             query_position,
            attention_mask,
             position_ids,
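The `head_dim` change above replaces a (truncated in this view) attribute access with a `getattr` fallback. A minimal sketch of why `proj.in_features // num_heads` recovers the head dimension when a vision block does not expose `head_dim`; the toy module is illustrative, not the real Qwen2.5-VL block:

```python
import torch.nn as nn

class TinyVisionBlock(nn.Module):
    """Hypothetical stand-in for a vision attention block without a head_dim attribute."""
    def __init__(self, dim: int = 1280, num_heads: int = 16):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)   # proj.in_features == embedding dim

block = TinyVisionBlock()
head_dim = getattr(block, "head_dim", block.proj.in_features // block.num_heads)
print(head_dim)  # 1280 // 16 = 80
```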
@@ -0,0 +1,23 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .configuration_resnet import RBLNResNetForImageClassificationConfig
+from .modeling_resnet import RBLNResNetForImageClassification
+
+
+__all__ = [
+    "RBLNResNetForImageClassificationConfig",
+    "RBLNResNetForImageClassification",
+]
@@ -0,0 +1,25 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
+    """
+    Configuration class for RBLNResNetForImageClassification.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized ResNet models for image classification tasks.
+    """
@@ -0,0 +1,26 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...modeling_generic import RBLNModelForImageClassification
+
+
+class RBLNResNetForImageClassification(RBLNModelForImageClassification):
+    """
+    RBLN optimized ResNet model for image classification tasks.
+
+    This class provides hardware-accelerated inference for ResNet models
+    on RBLN devices, supporting image classification with convolutional neural networks
+    designed for computer vision tasks.
+    """
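The three new ResNet files above only register the config and model classes; usage follows the same `from_pretrained(..., export=True)` / `save_pretrained(...)` pattern shown in the Qwen2 and Qwen2.5-VL docstrings earlier in this diff. A hedged sketch, assuming the class is re-exported from the package root like the other models in this release; the checkpoint name is an illustrative Hugging Face model, not something this package pins:

```python
from optimum.rbln import RBLNResNetForImageClassification

# "microsoft/resnet-50" is an assumed example checkpoint.
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    export=True,  # compile the PyTorch checkpoint into an RBLN graph
)
model.save_pretrained("compiled-resnet-50")
```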
@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_roberta import RBLNRobertaForMaskedLMConfig, RBLNRobertaForSequenceClassificationConfig
+from .modeling_roberta import RBLNRobertaForMaskedLM, RBLNRobertaForSequenceClassification
+
+
+__all__ = [
+    "RBLNRobertaForMaskedLMConfig",
+    "RBLNRobertaForSequenceClassificationConfig",
+    "RBLNRobertaForMaskedLM",
+    "RBLNRobertaForSequenceClassification",
+]
optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py}
RENAMED
@@ -12,38 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from
-    RBLNModelForAudioClassificationConfig,
-    RBLNModelForImageClassificationConfig,
-    RBLNModelForMaskedLMConfig,
-    RBLNModelForQuestionAnsweringConfig,
-    RBLNModelForSequenceClassificationConfig,
-)
+from ...configuration_generic import RBLNModelForMaskedLMConfig, RBLNModelForSequenceClassificationConfig
 
 
-class
-
-
-
-class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
-    pass
-
-
-class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
-    pass
-
+class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
+    """
+    Configuration class for RBLNRobertaForMaskedLM.
 
-class
-
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized RoBERTa models for masked language modeling tasks.
+    """
 
 
 class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
-
-
-
-class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
-    pass
-
+    """
+    Configuration class for RBLNRobertaForSequenceClassification.
 
-class
-
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized RoBERTa models for sequence classification tasks.
+    """
@@ -12,42 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from
-from .modeling_generic import (
-    RBLNModelForAudioClassification,
-    RBLNModelForImageClassification,
-    RBLNModelForMaskedLM,
-    RBLNModelForQuestionAnswering,
-    RBLNModelForSequenceClassification,
-)
+from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
 
 
-
-
-
-class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
-    pass
-
-
-class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
-    rbln_model_input_names = ["input_ids", "attention_mask"]
-
-
-class RBLNResNetForImageClassification(RBLNModelForImageClassification):
-    pass
+class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+    """
+    RBLN optimized RoBERTa model for masked language modeling tasks.
 
+    This class provides hardware-accelerated inference for RoBERTa models
+    on RBLN devices, supporting masked language modeling tasks such as
+    token prediction and text completion.
+    """
 
-class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
-
+    """
+    RBLN optimized RoBERTa model for sequence classification tasks.
 
+    This class provides hardware-accelerated inference for RoBERTa models
+    on RBLN devices, supporting text classification tasks such as sentiment analysis,
+    topic classification, and other sequence-level prediction tasks.
+    """
 
-class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
     rbln_model_input_names = ["input_ids", "attention_mask"]
-
-
-class RBLNViTForImageClassification(RBLNModelForImageClassification):
-    pass
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .
+from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py}
RENAMED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import rebel
 
@@ -31,7 +31,7 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
         pad_token_id: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -26,7 +26,7 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from .
+from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
 
 
 logger = get_logger(__name__)
@@ -161,16 +161,20 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             if "key_value_states" in name:
                 context.mark_static_address(tensor)
 
-        compiled_encoder =
+        compiled_encoder = cls.compile(
             wrapped_model.encoder,
             enc_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
             example_inputs=enc_example_inputs,
             compile_context=context,
         )
 
-        compiled_decoder =
+        compiled_decoder = cls.compile(
             wrapped_model.decoder,
             dec_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
             example_inputs=dec_example_inputs,
             compile_context=context,
         )
@@ -148,7 +148,8 @@ class Seq2SeqDecoderWrapper(nn.Module):
         new_layers = []
         for layer in model.get_decoder().layers:
             self_attn = Seq2SeqSelfAttention(layer.self_attn)
-
+            cross_attn = Seq2SeqCrossAttention(layer.encoder_attn)
+            new_layers.append(Seq2SeqDecoderLayer(layer, self_attn, cross_attn))
 
         decoder_model = Seq2SeqDecoder(model.get_decoder(), new_layers)
         new_model = Seq2SeqForConditionalGeneration(model, decoder_model)
@@ -341,10 +342,11 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         self_attn (Seq2SeqSelfAttention): Modified self-attention layer optimized for RBLN
     """
 
-    def __init__(self, decoder_layer, self_attn):
+    def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
         self._original_mod = decoder_layer
         self.self_attn = self_attn
+        self.cross_attn = cross_attn
         self.__post_init__()
 
     def __post_init__(self, **kwargs):
@@ -402,7 +404,8 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Cross-Attention Block
         residual = hidden_states
         hidden_states = self.pre_cross_attn_layer_norm(hidden_states)
-
+
+        cross_attn_output = self.cross_attn(
             hidden_states=hidden_states,
             past_key_value=cross_past_key_value,
             attention_mask=encoder_attention_mask,
@@ -487,3 +490,38 @@ class Seq2SeqSelfAttention(nn.Module):
         attn_output = self.out_proj(attn_output)
 
         return attn_output
+
+
+class Seq2SeqCrossAttention(nn.Module):
+    def __init__(self, attn, **kwargs):
+        super().__init__()
+        self._original_mod = attn
+        self.__post_init__(**kwargs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor = None,
+        past_key_value: Optional[object] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        is_cross_attention = key_value_states is not None
+        if is_cross_attention:
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
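The new `Seq2SeqCrossAttention` above reuses encoder key/value states that were computed once and stored in `past_key_value`, so each decoding step only projects the query and calls PyTorch's `scaled_dot_product_attention`. A shape-level sketch of that call with toy tensors (all sizes are illustrative):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only.
bsz, num_heads, head_dim = 1, 8, 64
tgt_len, src_len = 1, 128          # one decoding step attending over the encoder sequence

query = torch.randn(bsz, num_heads, tgt_len, head_dim)
enc_key = torch.randn(bsz, num_heads, src_len, head_dim)    # would come from past_key_value[0]
enc_value = torch.randn(bsz, num_heads, src_len, head_dim)  # would come from past_key_value[1]
enc_mask = torch.zeros(bsz, 1, tgt_len, src_len)            # additive mask; 0 means "attend"

attn_output = F.scaled_dot_product_attention(query, enc_key, enc_value, attn_mask=enc_mask)
print(attn_output.shape)  # torch.Size([1, 8, 1, 64])
```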