optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/__init__.py

```diff
@@ -30,17 +30,34 @@ _import_structure = {
     "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
-        "RBLNMistralForCausalLM",
         "RBLNXLMRobertaModel",
+        "RBLNMistralForCausalLM",
     ],
 }
 
@@ -48,14 +65,31 @@ if TYPE_CHECKING:
     from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
     from .models import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
```
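The classes registered above are also re-exported from the package root (optimum/rbln/__init__.py changes in this release as well, +37 -2). A minimal import check, as a sketch assuming the 0.1.11 wheel and its RBLN SDK dependencies are installed:

```python
# Minimal import smoke test for classes introduced in 0.1.11 (sketch; assumes
# the optimum-rbln 0.1.11 wheel and the RBLN runtime it depends on are installed).
from optimum.rbln import (
    RBLNAutoModelForCausalLM,
    RBLNBertModel,
    RBLNLlavaNextForConditionalGeneration,
    RBLNPhiForCausalLM,
)

for cls in (RBLNAutoModelForCausalLM, RBLNBertModel, RBLNLlavaNextForConditionalGeneration, RBLNPhiForCausalLM):
    print(cls.__name__)
```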
optimum/rbln/transformers/models/__init__.py

```diff
@@ -21,13 +21,32 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-
+
+from .auto import (
+    RBLNAutoModel,
+    RBLNAutoModelForAudioClassification,
+    RBLNAutoModelForCausalLM,
+    RBLNAutoModelForCTC,
+    RBLNAutoModelForDepthEstimation,
+    RBLNAutoModelForImageClassification,
+    RBLNAutoModelForMaskedLM,
+    RBLNAutoModelForQuestionAnswering,
+    RBLNAutoModelForSeq2SeqLM,
+    RBLNAutoModelForSequenceClassification,
+    RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForVision2Seq,
+)
+from .bart import RBLNBartModel
+from .bert import RBLNBertModel
+from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
 from .dpt import RBLNDPTForDepthEstimation
 from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
+from .llava_next import RBLNLlavaNextForConditionalGeneration
 from .midm import RBLNMidmLMHeadModel
 from .mistral import RBLNMistralForCausalLM
+from .phi import RBLNPhiForCausalLM
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
 from .xlm_roberta import RBLNXLMRobertaModel
```
optimum/rbln/transformers/models/auto/__init__.py (new file)

```diff
@@ -0,0 +1,14 @@
+from .modeling_auto import (
+    RBLNAutoModel,
+    RBLNAutoModelForAudioClassification,
+    RBLNAutoModelForCausalLM,
+    RBLNAutoModelForCTC,
+    RBLNAutoModelForDepthEstimation,
+    RBLNAutoModelForImageClassification,
+    RBLNAutoModelForMaskedLM,
+    RBLNAutoModelForQuestionAnswering,
+    RBLNAutoModelForSeq2SeqLM,
+    RBLNAutoModelForSequenceClassification,
+    RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForVision2Seq,
+)
```
optimum/rbln/transformers/models/auto/auto_factory.py (new file)

```diff
@@ -0,0 +1,84 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import importlib
+
+from transformers import AutoConfig
+
+
+class _BaseAutoModelClass:
+    # Base class for auto models.
+    _model_mapping = None
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_config(config)` methods."
+        )
+
+    @classmethod
+    def get_rbln_cls(
+        cls,
+        model_id,
+        *args,
+        **kwargs,
+    ):
+        # kwargs.update({"return_unused_kwargs": True})
+        config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
+
+        if len(config.architectures) > 1:
+            raise ValueError(
+                f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
+                f"{config.architectures}. `_BaseAutoModelClass` require exactly one architecture. "
+            )
+
+        architecture_name = config.architectures[0]
+        if architecture_name not in cls._model_mapping.values():
+            raise ValueError(
+                f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`."
+                "Please use the appropriate class's `from_pretrained()` method to load this model."
+            )
+
+        rbln_class_name = "RBLN" + architecture_name
+        module = importlib.import_module("optimum.rbln")
+
+        try:
+            rbln_cls = getattr(module, rbln_class_name)
+        except AttributeError as e:
+            raise AttributeError(
+                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
+                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
+            ) from e
+
+        return rbln_cls
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id,
+        *args,
+        **kwargs,
+    ):
+        rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
+        return rbln_cls.from_pretrained(model_id, *args, **kwargs)
```
optimum/rbln/transformers/models/auto/modeling_auto.py (new file)

```diff
@@ -0,0 +1,94 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_CTC_MAPPING_NAMES,
+    MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_MAPPING_NAMES,
+)
+
+from .auto_factory import _BaseAutoModelClass
+
+
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
+    {
+        "midm": "MidmLMHeadModel",
+    }
+)
+
+
+class RBLNAutoModel(_BaseAutoModelClass):
+    _model_mapping = MODEL_MAPPING_NAMES
+
+
+class RBLNAutoModelForCTC(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
+
+
+class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+
+
+class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
+
+
+class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
```
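Together, auto_factory.py and modeling_auto.py add an AutoModel-style entry point: `get_rbln_cls` loads the checkpoint's `AutoConfig`, requires exactly one entry in `config.architectures`, checks it against the task's mapping (extended here with `"midm": "MidmLMHeadModel"`), prefixes the name with `RBLN`, and resolves it from `optimum.rbln`; `from_pretrained` then forwards to the resolved class. A minimal usage sketch (the checkpoint ID is only an example; any RBLN-specific compile or runtime keyword arguments are outside these hunks and not shown):

```python
# Sketch of the auto-class dispatch added in 0.1.11.
from optimum.rbln import RBLNAutoModelForCausalLM

# "gpt2" declares architectures == ["GPT2LMHeadModel"], which appears in
# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so dispatch resolves to RBLNGPT2LMHeadModel.
rbln_cls = RBLNAutoModelForCausalLM.get_rbln_cls("gpt2")
print(rbln_cls.__name__)  # RBLNGPT2LMHeadModel

# from_pretrained forwards to the resolved class; depending on the checkpoint,
# RBLN compile/export options (not covered by these hunks) may also be needed.
model = RBLNAutoModelForCausalLM.from_pretrained("gpt2")
```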
optimum/rbln/transformers/models/bart/bart_architecture.py

```diff
@@ -54,6 +54,7 @@ class _BartAttention(BartAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -72,28 +73,83 @@ class _BartAttention(BartAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if cache_position.dim() > 0:
+            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
+                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
+                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
+                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
+                attn_weights = attn_weights + batch_attention_mask
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+                attn_output = torch.matmul(attn_weights, batch_value_states)
+                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
+            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            src_len = key_states.size(1)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.bmm(attn_weights, value_states)
+            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2)
+            key_states = key_states.unsqueeze(0)
+            value_states = value_states.unsqueeze(0)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -108,6 +164,7 @@ class _BartSdpaAttention(BartSdpaAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -126,23 +183,70 @@ class _BartSdpaAttention(BartSdpaAttention):
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
-
-
-
-
-
-
-
-
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_attention_mask = attention_mask[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
+                )
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+            )
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -162,6 +266,7 @@ class _BartDecoderLayer(BartDecoderLayer):
         encoder_hidden_states: torch.Tensor,
         past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         # Self Attention Block
@@ -174,6 +279,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -189,6 +295,7 @@ class _BartDecoderLayer(BartDecoderLayer):
             past_key_value=cross_attn_past_key_value,
             attention_mask=encoder_attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -213,14 +320,31 @@ class _BartDecoder(BartDecoder):
         encoder_hidden_states: torch.Tensor,
         past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ):
         # embedding
-
-
+        # thkim fix : transformers == 4.44.2 compile
+        if hasattr(self, "embed_scale"):
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+        else:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position.dim() == 0:
+            positions_idx = cache_position + self.embed_positions.offset
+            positions = self.embed_positions.weight[positions_idx]
+            hidden_states = inputs_embeds + positions
+        else:
+            hidden_all = []
+            for i in range(input_ids.shape[0]):
+                # cache position [N,1]
+                positions_idx = cache_position[i]
+                position_weight = self.embed_positions.weight[2:]
+                position = position_weight[positions_idx]
+                tmp_hidden = position + inputs_embeds[i]
+                hidden_all.append(tmp_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
 
-        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
 
         # prepare attn_mask
@@ -230,14 +354,14 @@ class _BartDecoder(BartDecoder):
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask,
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, input_shape, inputs_embeds, cache_position
            )
             encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask,
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
 
         # iterate decoder_layer
@@ -252,6 +376,7 @@ class _BartDecoder(BartDecoder):
                 encoder_attention_mask=encoder_attention_mask,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
                 attn_impl=attn_impl,
             )
             hidden_states = layer_outputs[0]
@@ -277,9 +402,14 @@ class BartDecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -291,7 +421,6 @@ class BartDecoderWrapper(torch.nn.Module):
                     cross_kv_cache[i + 1],
                 ),
             )
-
         # decode
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
@@ -302,6 +431,7 @@ class BartDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.config._attn_implementation,
+            batch_ids=rbln_batch_position,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.lm_head(sequence_output)
@@ -314,7 +444,7 @@ class BartDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache
+        return lm_logits, self_kv_cache, batch_position
 
 
 class BartEncoderWrapper(torch.nn.Module):
@@ -330,7 +460,13 @@ class BartEncoderWrapper(torch.nn.Module):
         self.num_heads = self.config.decoder_attention_heads
         self.d_kv = self.config.d_model // self.num_heads
 
-    def forward(
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor]:
         encoder_batch_size = input_ids.shape[0]
         decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
 
@@ -348,7 +484,7 @@ class BartEncoderWrapper(torch.nn.Module):
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.
+        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         decoder_outputs = _BartDecoder.forward(
@@ -359,14 +495,17 @@ class BartEncoderWrapper(torch.nn.Module):
             cache_position=torch.tensor(0, dtype=torch.int32),
            encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
             attn_impl=self.config._attn_implementation,
         )
         first_past_kv = decoder_outputs[1]
 
-        # 3. return cross_key_values to recurrence port. fyi (enc_ir.outputs[0] -> dec_ir.inputs[5])
         encoder_kv = []
-        for
-        encoder_kv.append(
-
+        for i in range(self.model.config.decoder_layers):
+            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
+            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
+        encoder_kv = torch.cat(encoder_kv, dim=0)
+
+        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
 
-        return
+        return cross_key_value
```
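The recurring pattern in these BART hunks is a per-sequence cache write: when `cache_position` carries one position per batch entry, each sequence's freshly projected key/value row is written into its own slot of the static KV cache with `Tensor.slice_scatter`, rather than one shared position for the whole batch. A toy illustration of that cache-update step, independent of the module's actual tensors (shapes and names here are illustrative only):

```python
import torch

# Toy static KV cache: (batch, heads, max_seq_len, head_dim).
batch, heads, max_len, head_dim = 2, 4, 16, 8
k_cache = torch.zeros(batch, heads, max_len, head_dim)

# One new key row per sequence and a per-sequence cache position,
# mirroring the decode branch where cache_position has shape [batch, 1].
new_k = torch.randn(batch, heads, 1, head_dim)
cache_position = torch.tensor([[3], [7]])

for b in range(batch):
    pos = int(cache_position[b][0])
    # Write this sequence's new key into its own slot along the sequence axis.
    k_cache[b] = k_cache[b].slice_scatter(new_k[b], dim=1, start=pos, end=pos + 1)

# Each sequence's row landed at its own position.
assert torch.allclose(k_cache[0, :, 3], new_k[0, :, 0])
assert torch.allclose(k_cache[1, :, 7], new_k[1, :, 0])
```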