optimum-rbln 0.7.4a2__py3-none-any.whl → 0.7.4a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. optimum/rbln/__version__.py +1 -1
  2. optimum/rbln/modeling.py +8 -1
  3. optimum/rbln/modeling_base.py +0 -5
  4. optimum/rbln/ops/__init__.py +3 -7
  5. optimum/rbln/ops/attn.py +271 -207
  6. optimum/rbln/ops/flash_attn.py +161 -67
  7. optimum/rbln/ops/kv_cache_update.py +4 -40
  8. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
  10. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
  11. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
  12. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  13. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  14. optimum/rbln/transformers/models/t5/modeling_t5.py +3 -37
  15. optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
  16. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  17. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
  18. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  19. optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
  20. optimum/rbln/transformers/models/whisper/whisper_architecture.py +20 -32
  21. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/METADATA +1 -1
  22. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/RECORD +24 -24
  23. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/WHEEL +0 -0
  24. {optimum_rbln-0.7.4a2.dist-info → optimum_rbln-0.7.4a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py
@@ -12,71 +12,165 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from functools import lru_cache
-
  import torch
- from packaging import version
-
-
- if version.parse(torch.__version__) > version.parse("2.4.0"):
-     register_fake = torch.library.register_fake
- else:
-     register_fake = torch.library.impl_abstract
-
-
- @lru_cache
- def register_rbln_custom_paged_flash_attention():
-     torch.library.define(
-         "rbln_custom_ops::paged_flash_attn_decode",
-         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-     )
-
-     @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
-     def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     @register_fake("rbln_custom_ops::paged_flash_attn_decode")
-     def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     torch.library.define(
-         "rbln_custom_ops::paged_flash_attn_prefill",
-         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-     )
-
-     @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
-     def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
-     def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-
- @lru_cache
- def register_rbln_custom_paged_flash_causal_attention():
-     torch.library.define(
-         "rbln_custom_ops::paged_flash_causal_attn_decode",
-         "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-     )
-
-     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
-     def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
-     def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     torch.library.define(
-         "rbln_custom_ops::paged_flash_causal_attn_prefill",
-         "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-     )
-
-     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
-     def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
-
-     @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
-     def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-         return q
+ from torch import Tensor
+
+
+ @torch.library.custom_op(
+     "rbln_custom_ops::paged_flash_attn_decode",
+     mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_flash_attn_decode(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     mask: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     """Defines the computation pattern for fused flash attention with KV cache for decoding.
+
+     Returns a tensor with the same shape as q.
+     """
+     return torch.empty_like(q)
+
+
+ @paged_flash_attn_decode.register_fake
+ def paged_flash_attn_decode_fake(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     mask: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+     "rbln_custom_ops::paged_flash_attn_prefill",
+     mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_flash_attn_prefill(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     mask: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     """Defines the computation pattern for fused flash attention with KV cache for prefill.
+
+     Returns a tensor with the same shape as q.
+     """
+     return torch.empty_like(q)
+
+
+ @paged_flash_attn_prefill.register_fake
+ def paged_flash_attn_prefill_fake(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     mask: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+     "rbln_custom_ops::paged_flash_causal_attn_decode",
+     mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_flash_causal_attn_decode(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
+
+     Returns a tensor with the same shape as q.
+     """
+     return torch.empty_like(q)
+
+
+ @paged_flash_causal_attn_decode.register_fake
+ def paged_flash_causal_attn_decode_fake(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     return torch.empty_like(q)
+
+
+ @torch.library.custom_op(
+     "rbln_custom_ops::paged_flash_causal_attn_prefill",
+     mutates_args=(["kcache", "vcache"]),
+ )
+ def paged_flash_causal_attn_prefill(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
+
+     Returns a tensor with the same shape as q.
+     """
+     return torch.empty_like(q)
+
+
+ @paged_flash_causal_attn_prefill.register_fake
+ def paged_flash_causal_attn_prefill_fake(
+     q: Tensor,
+     k: Tensor,
+     v: Tensor,
+     kcache: Tensor,
+     vcache: Tensor,
+     seq: Tensor,
+     scale: Tensor,
+     block_table: Tensor,
+     block_size: int,
+     partition: int,
+ ) -> Tensor:
+     return torch.empty_like(q)
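The hunk above replaces the old torch.library.define / impl_abstract registration, which needed a version shim around torch 2.4, with the torch.library.custom_op decorator available from PyTorch 2.4 onward. Each op now declares typed arguments, a fake (meta) implementation, and mutates_args for the KV-cache tensors it writes to. A minimal sketch of that registration pattern on a hypothetical op (demo_ops::scaled_add is illustrative only and not part of this package):

import torch
from torch import Tensor


# Hypothetical op, used only to illustrate the PyTorch >= 2.4 registration pattern.
@torch.library.custom_op("demo_ops::scaled_add", mutates_args=())
def scaled_add(x: Tensor, y: Tensor, alpha: float) -> Tensor:
    # Eager reference implementation.
    return x + alpha * y


@scaled_add.register_fake
def scaled_add_fake(x: Tensor, y: Tensor, alpha: float) -> Tensor:
    # Shape- and dtype-only implementation used while tracing or compiling.
    return torch.empty_like(x)


# The op is reachable through torch.ops and is traceable without the version shim.
out = torch.ops.demo_ops.scaled_add(torch.ones(4), torch.ones(4), alpha=2.0)

Declaring mutates_args matters for the attention ops above because it records which inputs (the KV caches) the real kernel updates in place, while the fake implementation only has to return a correctly shaped tensor.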
optimum/rbln/ops/kv_cache_update.py
@@ -12,49 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from functools import lru_cache
-
  import torch
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
-
+ from torch import Tensor

- if is_torch_greater_or_equal_than_2_4:
-     register_fake = torch.library.register_fake
- else:
-     register_fake = torch.library.impl_abstract

-
- @lru_cache
- def register_rbln_custom_cache_update():
+ @torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
+ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
      # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
      # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
      # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
-     torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
-
-     # Implementation of the "rbln_cache_update" operation for the CPU.
-     @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
-     def rbln_cache_update_cpu(cache, state, position, axis):
-         assert position.dim() == 0
-         assert axis.dim() == 0
-
-         # Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
-         s = position  # Start index for the update, specified by the position.
-         e = (
-             position + state.shape[axis]
-         )  # End index is determined by adding the size of the state along the given axis.
-
-         # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
-         # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-         cache.slice_scatter(state, dim=axis, start=s, end=e)
-
-         # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
-         return torch.empty([256])
-
-     # Register a "fake" implementation of the "rbln_cache_update" operation.
-     # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
-     @register_fake("rbln_custom_ops::rbln_cache_update")
-     def rbln_cache_update_abstract(cache, state, position, axis):
-         # Return a tensor with the same shape as the input cache tensor.
-         # This is a placeholder for the abstract implementation and does not perform any actual computation.
-         # Like the actual implementation, the abstraction assumes in-place device-side updates.
-         return torch.empty([256])
+     return torch.empty_like(cache)
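The new definition keeps only the graph-level contract: the Python body returns an empty tensor shaped like the cache, and mutates_args=(["cache"]) records that the device kernel updates the cache in place. The removed CPU reference spells out the intended semantics, namely writing state into cache along axis starting at position. A plain-PyTorch rendering of that reference, kept here only as documentation of what the device kernel is expected to do (it assumes position + state.shape[axis] stays within the cache):

import torch


def cache_update_reference(cache: torch.Tensor, state: torch.Tensor, position: int, axis: int) -> torch.Tensor:
    # Write `state` into `cache` along `axis`, starting at `position`,
    # mirroring the slice_scatter-based CPU reference removed above.
    end = position + state.shape[axis]
    return cache.slice_scatter(state, dim=axis, start=position, end=end)


cache = torch.zeros(4, 8, 16)  # e.g. (batch, seq, head_dim)
state = torch.ones(1, 8, 16)   # one batch entry worth of cross-KV state
cache = cache_update_reference(cache, state, position=2, axis=0)

Note that Tensor.slice_scatter returns a new tensor rather than mutating its input, so a host-side reference has to assign the result back; on device the compiled op performs the update in place.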
optimum/rbln/transformers/models/bart/__init__.py
@@ -12,4 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from ....ops import paged_attn_decode, paged_causal_attn_decode
  from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -12,4 +12,14 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from ....ops import (
+     paged_attn_decode,
+     paged_attn_prefill,
+     paged_causal_attn_decode,
+     paged_causal_attn_prefill,
+     paged_flash_attn_decode,
+     paged_flash_attn_prefill,
+     paged_flash_causal_attn_decode,
+     paged_flash_causal_attn_prefill,
+ )
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
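These imports are what trigger registration now: each op is defined at module level with torch.library.custom_op, so importing it registers it with the dispatcher as a side effect, which is why the explicit register_rbln_custom_*() calls disappear from the wrappers in the hunks below. A quick sanity check of that behaviour, assuming optimum.rbln.ops re-exports these names as the relative imports above suggest:

import torch

# Importing the op object registers it with the PyTorch dispatcher as a side effect.
from optimum.rbln.ops import paged_flash_attn_decode  # noqa: F401

# After the import, the op is reachable through the rbln_custom_ops namespace.
assert hasattr(torch.ops.rbln_custom_ops, "paged_flash_attn_decode")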
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -19,12 +19,6 @@ import torch
  from torch import nn
  from transformers import PretrainedConfig, PreTrainedModel

- from ....ops import (
-     register_rbln_custom_paged_attention,
-     register_rbln_custom_paged_causal_attention,
-     register_rbln_custom_paged_flash_attention,
-     register_rbln_custom_paged_flash_causal_attention,
- )
  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS

@@ -162,16 +156,8 @@ class DecoderOnlyWrapper(nn.Module):
          self.use_attention_mask = use_attention_mask
          if self.attn_impl == "flash_attn":
              self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-             if self.use_attention_mask:
-                 register_rbln_custom_paged_flash_attention()
-             else:
-                 register_rbln_custom_paged_flash_causal_attention()
          elif self.attn_impl == "eager":
              self.kvcache_partition_len = None
-             if self.use_attention_mask:
-                 register_rbln_custom_paged_attention()
-             else:
-                 register_rbln_custom_paged_causal_attention()
          else:
              raise ValueError(f"Unknown attn_impl : {self.attn_impl}")

@@ -756,55 +742,55 @@ class AttentionOp(nn.Module):
          if self.phase == "decode":
              if self.use_attention_mask:
                  attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                     query_state,
-                     key_state,
-                     value_state,
-                     attn_mask,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     block_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     mask=attn_mask,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=block_size,
                  )
              else:
                  attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                     query_state,
-                     key_state,
-                     value_state,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     block_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=block_size,
                  )

          else:
              if self.use_attention_mask:
                  attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                     query_state,
-                     key_state,
-                     value_state,
-                     attn_mask,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     block_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     mask=attn_mask,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=block_size,
                  )
              else:
                  attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                     query_state,
-                     key_state,
-                     value_state,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     block_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=block_size,
                  )

          attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -1015,58 +1001,58 @@ class FlashAttentionOp(AttentionOp):
          if self.phase == "decode":
              if self.use_attention_mask:
                  attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                     query_state,
-                     key_state,
-                     value_state,
-                     attn_mask,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     kvcache_block_size,
-                     self.kvcache_partition_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     mask=attn_mask,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=kvcache_block_size,
+                     partition=self.kvcache_partition_size,
                  )
              else:
                  attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                     query_state,
-                     key_state,
-                     value_state,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     kvcache_block_size,
-                     self.kvcache_partition_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=kvcache_block_size,
+                     partition=self.kvcache_partition_size,
                  )
          else:
              if self.use_attention_mask:
                  attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                     query_state,
-                     key_state,
-                     value_state,
-                     attn_mask,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     kvcache_block_size,
-                     self.kvcache_partition_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     mask=attn_mask,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=kvcache_block_size,
+                     partition=self.kvcache_partition_size,
                  )
              else:
                  attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                     query_state,
-                     key_state,
-                     value_state,
-                     past_key_state.unsqueeze(2),
-                     past_value_state.unsqueeze(2),
-                     seq_position,
-                     scale,
-                     block_tables,
-                     kvcache_block_size,
-                     self.kvcache_partition_size,
+                     q=query_state,
+                     k=key_state,
+                     v=value_state,
+                     kcache=past_key_state.unsqueeze(2),
+                     vcache=past_value_state.unsqueeze(2),
+                     seq=seq_position,
+                     scale=scale,
+                     block_table=block_tables,
+                     block_size=kvcache_block_size,
+                     partition=self.kvcache_partition_size,
                  )

          # reshape for removing repeat_kv
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -247,19 +247,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          enc_input_info = [
              ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
              ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cross_key_value_states",
-                 [
-                     n_layer * 2,
-                     rbln_batch_size,
-                     n_head,
-                     rbln_enc_max_seq_len,
-                     d_kv,
-                 ],
-                 "float32",
-             ),
              ("block_tables", [1], "int16"),
          ]
+         enc_input_info.extend(
+             [
+                 (
+                     f"cross_key_value_states_{i}",
+                     [
+                         rbln_batch_size,
+                         n_head,
+                         rbln_enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )

          dec_input_info = [
              ("input_ids", [rbln_batch_size, 1], "int64"),
@@ -274,9 +278,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          dec_input_info.extend(
              [
                  (
-                     "cross_key_value_states",
+                     f"cross_key_value_states_{i}",
                      [
-                         n_layer * 2,
                          rbln_batch_size,
                          n_head,
                          rbln_enc_max_seq_len,
@@ -284,6 +287,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                      ],
                      "float32",
                  )
+                 for i in range(n_layer * 2)
              ]
          )
          dec_input_info.extend(
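The change above replaces the single stacked cross_key_value_states input of shape [n_layer * 2, batch, n_head, enc_max_seq_len, d_kv] with n_layer * 2 separate inputs named cross_key_value_states_{i}, one key and one value tensor per decoder layer. A standalone sketch of the resulting input layout, with illustrative sizes (the numbers below are placeholders, not package defaults):

# Illustrative only: per-layer cross-KV input entries after this change.
n_layer, batch, n_head, enc_max_seq_len, d_kv = 2, 1, 8, 512, 64

enc_input_info = [
    ("input_ids", [1, enc_max_seq_len], "int64"),
    ("attention_mask", [1, enc_max_seq_len], "float32"),
    ("block_tables", [1], "int16"),
]
enc_input_info.extend(
    (f"cross_key_value_states_{i}", [batch, n_head, enc_max_seq_len, d_kv], "float32")
    for i in range(n_layer * 2)  # one key and one value tensor per decoder layer
)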
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -18,12 +18,6 @@ import torch
  from torch import nn
  from transformers.utils import logging

- from ....ops import (
-     register_rbln_custom_cache_update,
-     register_rbln_custom_paged_attention,
-     register_rbln_custom_paged_causal_attention,
- )
-

  logger = logging.get_logger(__name__)

@@ -59,7 +53,6 @@

      def __init__(self, model: nn.Module, enc_max_seq_len: int):
          super().__init__()
-         register_rbln_custom_cache_update()
          self.config = model.config
          self.encoder = model.get_encoder()
          self.encoder_max_length = enc_max_seq_len
@@ -90,8 +83,8 @@ class Seq2SeqEncoderWrapper(nn.Module):
          self,
          input_ids: torch.Tensor,
          attention_mask: torch.Tensor,
-         cross_key_values: torch.Tensor,
          b_idx: torch.Tensor,
+         *cross_key_values: Tuple[torch.Tensor],
      ) -> Tuple[torch.Tensor]:
          # 1. get encoder last_hidden_states
          encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,13 +103,15 @@ class Seq2SeqEncoderWrapper(nn.Module):
              cross_kv.append(past_k)
              cross_kv.append(past_v)

-         cross_kv = torch.stack(cross_kv, dim=0)
-
          # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
-         batch_axis = torch.tensor(1, dtype=torch.int16)
-         enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
+         batch_axis = torch.tensor(0, dtype=torch.int16)
+         cross_key_values = list(cross_key_values)
+         for i in range(self.n_layer * 2):
+             cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
+                 cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
+             )

-         return enc_out
+         return cross_key_values


  class Seq2SeqDecoderWrapper(nn.Module):
@@ -146,11 +141,6 @@ class Seq2SeqDecoderWrapper(nn.Module):
          It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
          by subclasses to modify or add custom attributes as necessary.
          """
-         if self.use_attention_mask:
-             register_rbln_custom_paged_attention()
-         else:
-             register_rbln_custom_paged_causal_attention()
-
          self.num_layers = self.config.decoder_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

@@ -176,16 +166,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
                  encoder_attention_mask,
                  cache_position,
                  block_tables,
-                 cross_kv_cache,
-                 *self_kv_cache,
+                 *kv_cache,
              ) = args

          else:
              attention_mask = None
-             (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+             (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args

          self_past_key_values = ()
          cross_past_key_values = ()
+         self_kv_cache = kv_cache[self.num_layers * 2 :]
+         cross_kv_cache = kv_cache[: self.num_layers * 2]
          for i in range(0, self.num_layers * 2, 2):
              self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
              cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
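With the decoder now taking a single flattened *kv_cache, the first num_layers * 2 tensors are the per-layer cross-attention caches and the remaining ones are the self-attention caches. A standalone sketch of that regrouping, mirroring the slicing and pairing done in the hunk above (strings stand in for tensors):

# Illustrative regrouping of a flattened KV-cache argument list.
num_layers = 2
kv_cache = [f"t{i}" for i in range(num_layers * 4)]  # stand-ins for K/V tensors

cross_kv_cache = kv_cache[: num_layers * 2]  # cross-attention K/V, layer-major
self_kv_cache = kv_cache[num_layers * 2 :]   # self-attention K/V, layer-major

cross_past_key_values = tuple(
    (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
self_past_key_values = tuple(
    (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)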