optimum-rbln 0.7.4a1__py3-none-any.whl → 0.7.4a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. optimum/rbln/__version__.py +1 -1
  2. optimum/rbln/modeling.py +8 -1
  3. optimum/rbln/ops/__init__.py +3 -7
  4. optimum/rbln/ops/attn.py +271 -207
  5. optimum/rbln/ops/flash_attn.py +161 -67
  6. optimum/rbln/ops/kv_cache_update.py +4 -40
  7. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  8. optimum/rbln/transformers/models/decoderonly/__init__.py +10 -0
  9. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +80 -94
  10. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +39 -20
  11. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +17 -13
  12. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  13. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  14. optimum/rbln/transformers/models/t5/t5_architecture.py +3 -4
  15. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  16. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +12 -22
  17. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  18. optimum/rbln/transformers/models/whisper/modeling_whisper.py +0 -1
  19. optimum/rbln/transformers/models/whisper/whisper_architecture.py +22 -34
  20. {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/METADATA +1 -1
  21. {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/RECORD +23 -23
  22. {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/WHEEL +0 -0
  23. {optimum_rbln-0.7.4a1.dist-info → optimum_rbln-0.7.4a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py
@@ -12,71 +12,165 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import lru_cache
-
 import torch
-from packaging import version
-
-
-if version.parse(torch.__version__) > version.parse("2.4.0"):
-    register_fake = torch.library.register_fake
-else:
-    register_fake = torch.library.impl_abstract
-
-
-@lru_cache
-def register_rbln_custom_paged_flash_attention():
-    torch.library.define(
-        "rbln_custom_ops::paged_flash_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-    )
-
-    @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
-    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    @register_fake("rbln_custom_ops::paged_flash_attn_decode")
-    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    torch.library.define(
-        "rbln_custom_ops::paged_flash_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-    )
-
-    @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
-    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
-    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-
-@lru_cache
-def register_rbln_custom_paged_flash_causal_attention():
-    torch.library.define(
-        "rbln_custom_ops::paged_flash_causal_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-    )
-
-    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
-    def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
-    def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    torch.library.define(
-        "rbln_custom_ops::paged_flash_causal_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
-    )
-
-    @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
-    def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
-
-    @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
-    def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode.register_fake
+def paged_flash_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill.register_fake
+def paged_flash_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode.register_fake
+def paged_flash_causal_attn_decode_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
+
+    Returns a tensor with the same shape as q.
+    """
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill.register_fake
+def paged_flash_causal_attn_prefill_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+) -> Tensor:
+    return torch.empty_like(q)
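
The ops above are plain `torch.library.custom_op` registrations, so they can be exercised on the host without RBLN hardware; both the Python body and the fake kernel only return a tensor shaped like `q`, which makes a host-side call effectively a schema and shape check. A minimal sketch follows; the tensor shapes are made up, and the re-export of `paged_flash_attn_decode` from `optimum.rbln.ops` is assumed from the imports that appear later in this diff.

    import torch

    from optimum.rbln.ops import paged_flash_attn_decode  # noqa: F401 -- importing registers the op (path assumed)

    # Illustrative shapes only; the host body is a placeholder returning empty_like(q).
    q = torch.randn(1, 8, 1, 64)            # [batch, heads, q_len, head_dim]
    k = torch.randn(1, 8, 1, 64)
    v = torch.randn(1, 8, 1, 64)
    mask = torch.zeros(1, 1, 1, 1, 128)
    kcache = torch.zeros(1, 8, 1, 2, 64)    # paged KV cache, layout assumed
    vcache = torch.zeros(1, 8, 1, 2, 64)
    seq = torch.tensor([0], dtype=torch.int32)
    scale = torch.tensor(64 ** -0.5)
    block_table = torch.zeros(1, 1, dtype=torch.int16)

    out = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
        q=q, k=k, v=v, mask=mask, kcache=kcache, vcache=vcache,
        seq=seq, scale=scale, block_table=block_table,
        block_size=128, partition=128,
    )
    assert out.shape == q.shape

The keyword-argument call sites in decoderonly_architecture.py further down in this diff rely on the parameter names in these signatures; on device, the RBLN compiler is expected to pattern-match the ops and emit the fused attention kernels.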
optimum/rbln/ops/kv_cache_update.py
@@ -12,49 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import lru_cache
-
 import torch
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
-
+from torch import Tensor
 
-if is_torch_greater_or_equal_than_2_4:
-    register_fake = torch.library.register_fake
-else:
-    register_fake = torch.library.impl_abstract
 
-
-@lru_cache
-def register_rbln_custom_cache_update():
+@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
+def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
     # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
     # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
     # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
-    torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
-
-    # Implementation of the "rbln_cache_update" operation for the CPU.
-    @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
-    def rbln_cache_update_cpu(cache, state, position, axis):
-        assert position.dim() == 0
-        assert axis.dim() == 0
-
-        # Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
-        s = position  # Start index for the update, specified by the position.
-        e = (
-            position + state.shape[axis]
-        )  # End index is determined by adding the size of the state along the given axis.
-
-        # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
-        # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-        cache.slice_scatter(state, dim=axis, start=s, end=e)
-
-        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
-        return torch.empty([256])
-
-    # Register a "fake" implementation of the "rbln_cache_update" operation.
-    # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
-    @register_fake("rbln_custom_ops::rbln_cache_update")
-    def rbln_cache_update_abstract(cache, state, position, axis):
-        # Return a tensor with the same shape as the input cache tensor.
-        # This is a placeholder for the abstract implementation and does not perform any actual computation.
-        # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.empty([256])
+    return torch.empty_like(cache)
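
The new op body no longer carries a host reference implementation; the removed CPU kernel above remains the clearest statement of the intended semantics. A minimal host-side sketch of that behavior (not part of the package), using `slice_scatter` as the removed CPU path did:

    import torch

    def cache_update_reference(cache, state, position, axis):
        # Copy `state` into `cache` starting at `position` along `axis`,
        # mirroring the removed CPU implementation of rbln_cache_update.
        start = int(position)
        end = start + state.shape[int(axis)]
        return cache.slice_scatter(state, dim=int(axis), start=start, end=end)

    cache = torch.zeros(1, 4, 16, 8)
    state = torch.ones(1, 4, 3, 8)
    updated = cache_update_reference(cache, state, position=torch.tensor(5), axis=torch.tensor(2))
    assert torch.equal(updated[:, :, 5:8], state)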
optimum/rbln/transformers/models/bart/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import paged_attn_decode, paged_causal_attn_decode
 from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -12,4 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import (
+    paged_attn_decode,
+    paged_attn_prefill,
+    paged_causal_attn_decode,
+    paged_causal_attn_prefill,
+    paged_flash_attn_decode,
+    paged_flash_attn_prefill,
+    paged_flash_causal_attn_decode,
+    paged_flash_causal_attn_prefill,
+)
 from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -19,12 +19,6 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-from ....ops import (
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-    register_rbln_custom_paged_flash_attention,
-    register_rbln_custom_paged_flash_causal_attention,
-)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
@@ -162,16 +156,8 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-            if self.use_attention_mask:
-                register_rbln_custom_paged_flash_attention()
-            else:
-                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-            if self.use_attention_mask:
-                register_rbln_custom_paged_attention()
-            else:
-                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
 
@@ -756,55 +742,55 @@ class AttentionOp(nn.Module):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
 
         else:
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -1015,58 +1001,58 @@ class FlashAttentionOp(AttentionOp):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                )
        else:
            if self.use_attention_mask:
                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                )
            else:
                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                )
 
        # reshape for removing repeat_kv
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -578,11 +578,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         nbits_per_param: int,
         n_model_params: int,
     ) -> int:
+        """
+        We are finding max_n_blocks(x) that satisfies the following equation:
+
+        available_dram - kernel_size - buffer
+            - num_layers * 2 * tensor_parallel_size
+            * align_2MB(
+                x
+                * block_size
+                * align_64(head_dim)
+                * math.ceil(num_key_value_heads / tensor_parallel_size)
+                * 2
+            ) > 0
+
+        This inequality can be rewritten as follows:
+
+        a - c * align_2MB(b * x) > 0
+        where
+            a = available_dram - kernel_size - buffer
+            b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+            c = num_layers * 2 * tensor_parallel_size
+
+        We can rewrite the inequality as follows:
+            k > align_2MB(b * x)
+        where
+            k = a / c
+
+        After that, we can derive the following equation:
+            x = floor(2**21 / b * floor((k - 1) / 2**21))
+        """
+
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
 
         def align_2MB(x: int) -> int:
-            return align(x, 2 * 1024 * 1024)
+            return align(x, 2**21)
 
         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
@@ -612,27 +642,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         available_dram -= kernel_size
 
         # TODO: Accurate buffer estimation
-        buffer = 2**30  # 1GB Buffer
-        if tensor_parallel_size <= 4:
-            buffer /= 4
-
+        buffer_per_core = 2**29  # 500MB per npu
+        buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer
 
-        # Estimate nbytes per a single kvcache block
-        nbytes_per_block = (
-            align_2MB(
-                kvcache_block_size
-                * head_dim
-                * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
-                * 2  # (fp16)
-            )
-            * num_layers
-            * 2  # (k, v)
-            * tensor_parallel_size
-        )
-        n_blocks = available_dram // nbytes_per_block
+        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        c = num_layers * 2 * tensor_parallel_size
+        k = available_dram / c
+        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
 
-        return n_blocks, nbytes_per_block
+        return max_n_blocks
 
     @classmethod
     def _get_rbln_config(
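
For concreteness, the closed form derived in the docstring above can be evaluated with a small standalone sketch; the figures below are illustrative placeholders, not RBLN hardware numbers:

    import math

    def align(x: int, nbytes: int) -> int:
        return int(math.ceil(x / nbytes) * nbytes)

    def max_kvcache_blocks(available_dram, num_layers, num_key_value_heads,
                           head_dim, kvcache_block_size, tensor_parallel_size):
        # b: bytes per block per (layer, k/v, shard) allocation before 2 MiB alignment (fp16 -> 2 bytes)
        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
        # c: number of such allocations (layers x {k, v} x tensor-parallel shards)
        c = num_layers * 2 * tensor_parallel_size
        k = available_dram / c
        return math.floor(2**21 / b * math.floor((k - 1) / 2**21))

    # Illustrative numbers: 40 GiB of DRAM left after kernel and buffer, 32 layers,
    # 8 KV heads, head_dim 128, block size 16384, tensor parallel size 4.
    print(max_kvcache_blocks(40 * 2**30, 32, 8, 128, 16384, 4))  # -> 19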
@@ -689,7 +708,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
         if rbln_attn_impl == "flash_attn":
-            max_num_blocks, _ = cls.get_maximum_num_blocks(
+            max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
                 tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
                 kvcache_block_size=rbln_kvcache_block_size,
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -247,19 +247,23 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         enc_input_info = [
             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
             ("block_tables", [1], "int16"),
         ]
+        enc_input_info.extend(
+            [
+                (
+                    f"cross_key_value_states_{i}",
+                    [
+                        rbln_batch_size,
+                        n_head,
+                        rbln_enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+                for i in range(n_layer * 2)
+            ]
+        )
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
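
The encoder I/O spec above replaces the single stacked cross-attention KV tensor with one named input per layer and per key/value projection. With made-up dimensions, the expansion produces entries like this:

    n_layer, rbln_batch_size, n_head, rbln_enc_max_seq_len, d_kv = 2, 1, 8, 512, 64
    cross_kv_inputs = [
        (f"cross_key_value_states_{i}", [rbln_batch_size, n_head, rbln_enc_max_seq_len, d_kv], "float32")
        for i in range(n_layer * 2)
    ]
    # -> cross_key_value_states_0 ... cross_key_value_states_3, each [1, 8, 512, 64],
    #    instead of a single cross_key_value_states of shape [4, 1, 8, 512, 64]

The decoder spec below is reworked the same way.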
@@ -274,9 +278,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         dec_input_info.extend(
             [
                 (
-                    "cross_key_value_states",
+                    f"cross_key_value_states_{i}",
                     [
-                        n_layer * 2,
                         rbln_batch_size,
                         n_head,
                         rbln_enc_max_seq_len,
@@ -284,6 +287,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                     ],
                     "float32",
                 )
+                for i in range(n_layer * 2)
             ]
         )
         dec_input_info.extend(