optimum_rbln-0.1.9-py3-none-any.whl → optimum_rbln-0.1.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/__init__.py
@@ -30,17 +30,34 @@ _import_structure = {
     "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
-        "RBLNMistralForCausalLM",
         "RBLNXLMRobertaModel",
+        "RBLNMistralForCausalLM",
     ],
 }
 
@@ -48,14 +65,31 @@ if TYPE_CHECKING:
     from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
     from .models import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
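
Every name added to _import_structure above is mirrored in the TYPE_CHECKING block because this __init__.py follows the transformers lazy-module convention: the dictionary drives runtime attribute resolution while the guarded imports exist only for type checkers, so the two lists must stay in sync. A quick import check under that assumption (requires the 0.1.11 wheel to be installed; the three names are taken from the hunk above):

import optimum.rbln.transformers as rbln_transformers

# Lazy resolution: each attribute access triggers the real submodule import.
for name in ("RBLNAutoModelForCausalLM", "RBLNBertModel", "RBLNPhiForCausalLM"):
    assert hasattr(rbln_transformers, name), name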

optimum/rbln/transformers/models/__init__.py
@@ -21,13 +21,32 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
+
+from .auto import (
+    RBLNAutoModel,
+    RBLNAutoModelForAudioClassification,
+    RBLNAutoModelForCausalLM,
+    RBLNAutoModelForCTC,
+    RBLNAutoModelForDepthEstimation,
+    RBLNAutoModelForImageClassification,
+    RBLNAutoModelForMaskedLM,
+    RBLNAutoModelForQuestionAnswering,
+    RBLNAutoModelForSeq2SeqLM,
+    RBLNAutoModelForSequenceClassification,
+    RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForVision2Seq,
+)
+from .bart import RBLNBartModel
+from .bert import RBLNBertModel
+from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
 from .dpt import RBLNDPTForDepthEstimation
 from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
+from .llava_next import RBLNLlavaNextForConditionalGeneration
 from .midm import RBLNMidmLMHeadModel
 from .mistral import RBLNMistralForCausalLM
+from .phi import RBLNPhiForCausalLM
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
 from .xlm_roberta import RBLNXLMRobertaModel

optimum/rbln/transformers/models/auto/__init__.py (new file)
@@ -0,0 +1,14 @@
+from .modeling_auto import (
+    RBLNAutoModel,
+    RBLNAutoModelForAudioClassification,
+    RBLNAutoModelForCausalLM,
+    RBLNAutoModelForCTC,
+    RBLNAutoModelForDepthEstimation,
+    RBLNAutoModelForImageClassification,
+    RBLNAutoModelForMaskedLM,
+    RBLNAutoModelForQuestionAnswering,
+    RBLNAutoModelForSeq2SeqLM,
+    RBLNAutoModelForSequenceClassification,
+    RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForVision2Seq,
+)

optimum/rbln/transformers/models/auto/auto_factory.py (new file)
@@ -0,0 +1,84 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import importlib
+
+from transformers import AutoConfig
+
+
+class _BaseAutoModelClass:
+    # Base class for auto models.
+    _model_mapping = None
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_config(config)` methods."
+        )
+
+    @classmethod
+    def get_rbln_cls(
+        cls,
+        model_id,
+        *args,
+        **kwargs,
+    ):
+        # kwargs.update({"return_unused_kwargs": True})
+        config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
+
+        if len(config.architectures) > 1:
+            raise ValueError(
+                f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
+                f"{config.architectures}. `_BaseAutoModelClass` requires exactly one architecture."
+            )
+
+        architecture_name = config.architectures[0]
+        if architecture_name not in cls._model_mapping.values():
+            raise ValueError(
+                f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`. "
+                "Please use the appropriate class's `from_pretrained()` method to load this model."
+            )
+
+        rbln_class_name = "RBLN" + architecture_name
+        module = importlib.import_module("optimum.rbln")
+
+        try:
+            rbln_cls = getattr(module, rbln_class_name)
+        except AttributeError as e:
+            raise AttributeError(
+                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
+                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
+            ) from e
+
+        return rbln_cls
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id,
+        *args,
+        **kwargs,
+    ):
+        rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
+        return rbln_cls.from_pretrained(model_id, *args, **kwargs)
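
The dispatch in get_rbln_cls is purely name based: it reads config.architectures from the checkpoint, requires exactly one entry, prefixes it with "RBLN", and looks the result up on the optimum.rbln package; from_pretrained then forwards all arguments to the resolved class. A minimal standalone sketch of that resolution step (the architecture names in the comments are illustrative, not taken from this release):

import importlib

from transformers import AutoConfig


def resolve_rbln_class(model_id: str):
    # Mirrors _BaseAutoModelClass.get_rbln_cls: single architecture -> "RBLN" prefix -> getattr.
    config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)[0]
    architecture_name = config.architectures[0]    # e.g. "LlamaForCausalLM"
    rbln_class_name = "RBLN" + architecture_name   # e.g. "RBLNLlamaForCausalLM"
    return getattr(importlib.import_module("optimum.rbln"), rbln_class_name)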

optimum/rbln/transformers/models/auto/modeling_auto.py (new file)
@@ -0,0 +1,94 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_CTC_MAPPING_NAMES,
+    MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_MAPPING_NAMES,
+)
+
+from .auto_factory import _BaseAutoModelClass
+
+
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
+    {
+        "midm": "MidmLMHeadModel",
+    }
+)
+
+
+class RBLNAutoModel(_BaseAutoModelClass):
+    _model_mapping = MODEL_MAPPING_NAMES
+
+
+class RBLNAutoModelForCTC(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
+
+
+class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+
+
+class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
+
+
+class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
+
+
+class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+
+
+class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
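
Each task class simply points at the corresponding transformers name mapping, and the update() call above registers the out-of-tree Midm architecture in the causal-LM mapping so it dispatches like a stock architecture. A small sanity check of that wiring, assuming the auto classes are also re-exported from the package root (the top-level __init__.py changes in this release suggest they are):

from optimum.rbln import RBLNAutoModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

# Importing the RBLN auto classes patches "midm" into the shared mapping, so a
# Midm checkpoint resolves to RBLNMidmLMHeadModel through the auto class.
assert MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["midm"] == "MidmLMHeadModel"
assert RBLNAutoModelForCausalLM._model_mapping is MODEL_FOR_CAUSAL_LM_MAPPING_NAMES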

optimum/rbln/transformers/models/bart/__init__.py
@@ -22,3 +22,4 @@
 # from Rebellions Inc.
 
 from .bart_architecture import BartDecoderWrapper, BartEncoderWrapper
+from .modeling_bart import RBLNBartModel

optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -54,6 +54,7 @@ class _BartAttention(BartAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -72,28 +73,83 @@
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.reshape(*proj_shape)
-        value_states = value_states.reshape(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_output = torch.bmm(attn_weights, value_states)
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        if cache_position.dim() > 0:
+            proj_shape = (bsz, self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0).unsqueeze(2)
+                batch_attention_mask = attention_mask[b].unsqueeze(0).unsqueeze(2)
+                batch_key_states = key_states[b].unsqueeze(0).unsqueeze(2)
+                batch_value_states = value_states[b].unsqueeze(0).unsqueeze(2)
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .unsqueeze(2)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4))
+                attn_weights = attn_weights + batch_attention_mask
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+                attn_output = torch.matmul(attn_weights, batch_value_states)
+                attn_output = attn_output.view(1, self.num_heads, tgt_len, self.head_dim)
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+            key_states = torch.cat(all_key_states, dim=0).squeeze(2)
+            value_states = torch.cat(all_value_states, dim=0).squeeze(2)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+            key_states = key_states.reshape(*proj_shape)
+            value_states = value_states.reshape(*proj_shape)
+
+            src_len = key_states.size(1)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+            attn_output = torch.bmm(attn_weights, value_states)
+            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2)
+            key_states = key_states.unsqueeze(0)
+            value_states = value_states.unsqueeze(0)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
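
In both branches the self-attention KV cache is updated in place with slice_scatter instead of torch.cat, so the cache tensor keeps a fixed, pre-allocated shape (presumably what the RBLN compiler requires); the new prefill branch applies the same write per batch entry at that entry's own cache position. A standalone illustration of the write pattern, with made-up shapes:

import torch

batch, heads, max_len, head_dim = 1, 2, 8, 4
past_key = torch.zeros(batch, heads, max_len, head_dim)  # pre-allocated cache slot
new_key = torch.randn(batch, heads, 1, head_dim)          # key for the current decode step
cache_position = 3

# Overwrite exactly one time step; the cache shape never changes.
past_key = past_key.slice_scatter(new_key, dim=2, start=cache_position, end=cache_position + 1)
assert past_key.shape == (batch, heads, max_len, head_dim)
assert torch.equal(past_key[:, :, cache_position], new_key[:, :, 0])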
@@ -108,6 +164,7 @@ class _BartSdpaAttention(BartSdpaAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_index: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -126,23 +183,70 @@
         else:
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = past_key_value[0].slice_scatter(
-                key_states, dim=2, start=cache_position, end=cache_position + 1
-            )
-            value_states = past_key_value[1].slice_scatter(
-                value_states, dim=2, start=cache_position, end=cache_position + 1
-            )
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-        )
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                batch_query_states = query_states[b].unsqueeze(0)
+                batch_attention_mask = attention_mask[b].unsqueeze(0)
+                batch_key_states = key_states[b].unsqueeze(0)
+                batch_value_states = value_states[b].unsqueeze(0)
+
+                if not is_cross_attention:
+                    batch_key_states = (
+                        past_key_value[0][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+                    batch_value_states = (
+                        past_key_value[1][b]
+                        .unsqueeze(0)
+                        .slice_scatter(
+                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                        )
+                    )
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    batch_query_states, batch_key_states, batch_value_states, attn_mask=batch_attention_mask
+                )
+                attn_output = attn_output.transpose(1, 2)
+                attn_output = attn_output.reshape(1, tgt_len, self.embed_dim)
+                all_key_states.append(batch_key_states)
+                all_value_states.append(batch_value_states)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            if not is_cross_attention:
+                key_states = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                value_states = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+            )
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
         attn_output = self.out_proj(attn_output)
 
         present_key_value = (key_states, value_states)
@@ -162,6 +266,7 @@ class _BartDecoderLayer(BartDecoderLayer):
         encoder_hidden_states: torch.Tensor,
         past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         # Self Attention Block
@@ -174,6 +279,7 @@
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -189,6 +295,7 @@
             past_key_value=cross_attn_past_key_value,
             attention_mask=encoder_attention_mask,
             cache_position=cache_position,
+            batch_index=batch_ids,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -213,14 +320,31 @@ class _BartDecoder(BartDecoder):
         encoder_hidden_states: torch.Tensor,
         past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_ids: torch.Tensor,
         attn_impl: str = "eager",
     ):
         # embedding
-        positions_idx = cache_position + self.embed_positions.offset
-        positions = self.embed_positions.weight[positions_idx]
+        # thkim fix : transformers == 4.44.2 compile
+        if hasattr(self, "embed_scale"):
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+        else:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position.dim() == 0:
+            positions_idx = cache_position + self.embed_positions.offset
+            positions = self.embed_positions.weight[positions_idx]
+            hidden_states = inputs_embeds + positions
+        else:
+            hidden_all = []
+            for i in range(input_ids.shape[0]):
+                # cache position [N,1]
+                positions_idx = cache_position[i]
+                position_weight = self.embed_positions.weight[2:]
+                position = position_weight[positions_idx]
+                tmp_hidden = position + inputs_embeds[i]
+                hidden_all.append(tmp_hidden)
+            hidden_states = torch.stack(hidden_all, dim=0)
 
-        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-        hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
 
         # prepare attn_mask
@@ -230,14 +354,14 @@
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
            )
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, input_shape, inputs_embeds, cache_position
             )
             encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                encoder_attention_mask, torch.float32, tgt_len=input_shape[-1]
             )
 
         # iterate decoder_layer
@@ -252,6 +376,7 @@
                 encoder_attention_mask=encoder_attention_mask,
                 past_key_value=past_key_value,
                 cache_position=cache_position,
+                batch_ids=batch_ids,
                 attn_impl=attn_impl,
             )
             hidden_states = layer_outputs[0]
@@ -277,9 +402,14 @@ class BartDecoderWrapper(torch.nn.Module):
         attention_mask: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -291,7 +421,6 @@
                     cross_kv_cache[i + 1],
                 ),
             )
-
         # decode
         decoder_outputs = _BartDecoder.forward(
             self.decoder,
@@ -302,6 +431,7 @@
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.config._attn_implementation,
+            batch_ids=rbln_batch_position,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.lm_head(sequence_output)
@@ -314,7 +444,7 @@
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-        return lm_logits, self_kv_cache
+        return lm_logits, self_kv_cache, batch_position
 
 
 class BartEncoderWrapper(torch.nn.Module):
@@ -330,7 +460,13 @@ class BartEncoderWrapper(torch.nn.Module):
         self.num_heads = self.config.decoder_attention_heads
         self.d_kv = self.config.d_model // self.num_heads
 
-    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> Tuple[torch.Tensor]:
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.LongTensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor]:
         encoder_batch_size = input_ids.shape[0]
         decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
 
@@ -348,7 +484,7 @@ class BartEncoderWrapper(torch.nn.Module):
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
+        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         decoder_outputs = _BartDecoder.forward(
@@ -359,14 +495,17 @@
             cache_position=torch.tensor(0, dtype=torch.int32),
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
             attn_impl=self.config._attn_implementation,
         )
         first_past_kv = decoder_outputs[1]
 
-        # 3. return cross_key_values to recurrence port. fyi (enc_ir.outputs[0] -> dec_ir.inputs[5])
         encoder_kv = []
-        for layer_out in first_past_kv:  # for layer
-            encoder_kv.append(torch.stack(layer_out[2:], dim=0))
-        encoder_kv = torch.stack(encoder_kv, dim=0)
+        for i in range(self.model.config.decoder_layers):
+            encoder_kv.append(first_past_kv[i][2].unsqueeze(0))
+            encoder_kv.append(first_past_kv[i][3].unsqueeze(0))
+        encoder_kv = torch.cat(encoder_kv, dim=0)
+
+        cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
 
-        return encoder_kv
+        return cross_key_value
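
With this change the encoder wrapper no longer returns a freshly stacked cross-attention KV tensor; it scatters the KV of the single sequence it just encoded into a caller-provided, batch-sized buffer at batch_idx, so each request in a batch can run its encoder pass independently. A toy version of that batch-slot write (all sizes illustrative):

import torch

kv_pairs, batch, heads, enc_len, head_dim = 4, 3, 2, 6, 8  # kv_pairs = 2 * decoder layers
cross_key_value = torch.zeros(kv_pairs, batch, heads, enc_len, head_dim)  # shared buffer
encoder_kv = torch.randn(kv_pairs, 1, heads, enc_len, head_dim)           # one sequence's KV
batch_idx = 1

# Write this sequence's cross KV into its slot of the shared buffer.
cross_key_value = cross_key_value.slice_scatter(encoder_kv, dim=1, start=batch_idx, end=batch_idx + 1)
assert torch.equal(cross_key_value[:, batch_idx], encoder_kv[:, 0])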